mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Implement experimental PostgreSQL store (#1939)
* migrate sqlite store to generic sql store * fix conflicts * init postgres store * Add postgres store tests * Refactor postgres store engine name * fix tests * Run postgres store tests on linux only * fix tests * Refactor * cascade policy rules on policy deletion * fix tests * run postgres cases in new db * close store connection after tests * refactor * using testcontainers * sync go sum * remove postgres service * remove store cleanup * go mod tidy * remove env * use postgres as engine and initialize test store with testcontainer --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
@@ -62,10 +62,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var shortDown = "Rollback SQLite store to JSON file store. Please make a backup of the SQLite file before running this command."
|
||||
@@ -39,16 +40,16 @@ var downCmd = &cobra.Command{
|
||||
return fmt.Errorf("%s already exists, couldn't continue the operation", fileStorePath)
|
||||
}
|
||||
|
||||
sqlstore, err := server.NewSqliteStore(mgmtDataDir, nil)
|
||||
sqlStore, err := server.NewSqliteStore(mgmtDataDir, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
|
||||
}
|
||||
|
||||
sqliteStoreAccounts := len(sqlstore.GetAllAccounts())
|
||||
sqliteStoreAccounts := len(sqlStore.GetAllAccounts())
|
||||
log.Infof("%d account will be migrated from sqlite store %s to file store %s",
|
||||
sqliteStoreAccounts, sqliteStorePath, fileStorePath)
|
||||
|
||||
store, err := server.NewFilestoreFromSqliteStore(sqlstore, mgmtDataDir, nil)
|
||||
store, err := server.NewFilestoreFromSqliteStore(sqlStore, mgmtDataDir, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
|
||||
}
|
||||
|
||||
@@ -1294,6 +1294,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
userID := "account_creator"
|
||||
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
@@ -1655,6 +1656,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -1707,6 +1709,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -1750,6 +1753,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -2267,21 +2271,29 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
@@ -200,10 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -57,18 +57,18 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
|
||||
}
|
||||
|
||||
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
|
||||
func NewFilestoreFromSqliteStore(sqlitestore *SqliteStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(sqlitestore.GetInstallationID())
|
||||
err = store.SaveInstallationID(sqlStore.GetInstallationID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range sqlitestore.GetAllAccounts() {
|
||||
for _, account := range sqlStore.GetAllAccounts() {
|
||||
store.Accounts[account.Id] = account
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
if store.Accounts == nil || len(store.Accounts) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@@ -87,6 +88,7 @@ func TestNewStore(t *testing.T) {
|
||||
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
@@ -135,6 +137,8 @@ func TestDeleteAccount(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
var account *Account
|
||||
for _, a := range store.Accounts {
|
||||
account = a
|
||||
@@ -179,6 +183,7 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
@@ -436,6 +441,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
|
||||
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
|
||||
|
||||
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
|
||||
|
||||
@@ -405,10 +405,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
@@ -532,10 +532,11 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s := grpc.NewServer()
|
||||
|
||||
store, err := server.NewStoreFromJson(config.Datadir, nil)
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
@@ -766,10 +766,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ type Policy struct {
|
||||
Enabled bool
|
||||
|
||||
// Rules of the policy
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id"`
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// SourcePostureChecks are ID references to Posture checks for policy source groups
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
|
||||
@@ -3,8 +3,9 @@ package server
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -1021,10 +1021,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -12,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -28,14 +26,14 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
|
||||
type SqliteStore struct {
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
storeFile string
|
||||
accountLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine StoreEngine
|
||||
}
|
||||
|
||||
type installation struct {
|
||||
@@ -45,24 +43,8 @@ type installation struct {
|
||||
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
// NewSqliteStore restores a store from the file located in the datadir
|
||||
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||
storeStr := "store.db?cache=shared"
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = "store.db"
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NewSqlStore creates a new SqlStore instance.
|
||||
func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
sql, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -82,33 +64,11 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
|
||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
||||
}
|
||||
|
||||
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
|
||||
}
|
||||
|
||||
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir
|
||||
func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||
store, err := NewSqliteStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(filestore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range filestore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
}
|
||||
|
||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||
func (s *SqlStore) AcquireGlobalLock() (unlock func()) {
|
||||
log.Tracef("acquiring global lock")
|
||||
start := time.Now()
|
||||
s.globalAccountLock.Lock()
|
||||
@@ -127,7 +87,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring write lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -143,7 +103,7 @@ func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func())
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring read lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -159,7 +119,7 @@ func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
func (s *SqlStore) SaveAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
for _, key := range account.SetupKeys {
|
||||
@@ -225,12 +185,12 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
||||
log.Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||
func (s *SqlStore) DeleteAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
@@ -256,19 +216,19 @@ func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to delete an account to the SQLite", took.Milliseconds())
|
||||
log.Debugf("took %d ms to delete an account to the store", took.Milliseconds())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SaveInstallationID(ID string) error {
|
||||
func (s *SqlStore) SaveInstallationID(ID string) error {
|
||||
installation := installation{InstallationIDValue: ID}
|
||||
installation.ID = uint(s.installationPK)
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetInstallationID() string {
|
||||
func (s *SqlStore) GetInstallationID() string {
|
||||
var installation installation
|
||||
|
||||
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
|
||||
@@ -278,7 +238,7 @@ func (s *SqliteStore) GetInstallationID() string {
|
||||
return installation.InstallationIDValue
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
@@ -296,7 +256,7 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
@@ -318,17 +278,17 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
||||
func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex is noop in Sqlite
|
||||
func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
// DeleteTokenID2UserIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
@@ -345,7 +305,7 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
return s.GetAccount(account.Id)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
var key SetupKey
|
||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
||||
if result.Error != nil {
|
||||
@@ -363,7 +323,7 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
return s.GetAccount(key.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
@@ -377,7 +337,7 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "id = ?", tokenID)
|
||||
if result.Error != nil {
|
||||
@@ -406,7 +366,7 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||
func (s *SqlStore) GetAllAccounts() (all []*Account) {
|
||||
var accounts []Account
|
||||
result := s.db.Find(&accounts)
|
||||
if result.Error != nil {
|
||||
@@ -422,7 +382,7 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||
return all
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccount(accountID string) (*Account, error) {
|
||||
|
||||
var account Account
|
||||
result := s.db.Model(&account).
|
||||
@@ -490,7 +450,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
var user User
|
||||
result := s.db.Select("account_id").First(&user, "id = ?", userID)
|
||||
if result.Error != nil {
|
||||
@@ -508,7 +468,7 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
return s.GetAccount(user.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
|
||||
if result.Error != nil {
|
||||
@@ -526,7 +486,7 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
||||
@@ -545,7 +505,7 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
@@ -561,7 +521,7 @@ func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
|
||||
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
||||
@@ -579,7 +539,7 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
|
||||
}
|
||||
|
||||
// Close closes the underlying DB connection
|
||||
func (s *SqliteStore) Close() error {
|
||||
func (s *SqlStore) Close() error {
|
||||
sql, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get db: %w", err)
|
||||
@@ -587,40 +547,85 @@ func (s *SqliteStore) Close() error {
|
||||
return sql.Close()
|
||||
}
|
||||
|
||||
// GetStoreEngine returns SqliteStoreEngine
|
||||
func (s *SqliteStore) GetStoreEngine() StoreEngine {
|
||||
return SqliteStoreEngine
|
||||
// GetStoreEngine returns underlying store engine
|
||||
func (s *SqlStore) GetStoreEngine() StoreEngine {
|
||||
return s.storeEngine
|
||||
}
|
||||
|
||||
// migrate migrates the SQLite database to the latest schema
|
||||
func migrate(db *gorm.DB) error {
|
||||
migrations := getMigrations()
|
||||
// NewSqliteStore creates a new SQLite store.
|
||||
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
storeStr := "store.db?cache=shared"
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = "store.db"
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(db, SqliteStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(db, PostgresStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
|
||||
func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewSqliteStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(fileStore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func getMigrations() []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
|
||||
func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewPostgresqlStore(dsn, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(fileStore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -569,7 +570,7 @@ func TestMigrate(t *testing.T) {
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
}
|
||||
|
||||
func newSqliteStore(t *testing.T) *SqliteStore {
|
||||
func newSqliteStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
store, err := NewSqliteStore(t.TempDir(), nil)
|
||||
@@ -579,7 +580,7 @@ func newSqliteStore(t *testing.T) *SqliteStore {
|
||||
return store
|
||||
}
|
||||
|
||||
func newSqliteStoreFromFile(t *testing.T, filename string) *SqliteStore {
|
||||
func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
storeDir := t.TempDir()
|
||||
@@ -613,3 +614,298 @@ func newAccount(store Store, id int) error {
|
||||
|
||||
return store.SaveAccount(account)
|
||||
}
|
||||
|
||||
func newPostgresqlStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
cleanUp, err := createPGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStore(postgresDsn, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("could not initialize postgresql store: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
storeDir := t.TempDir()
|
||||
err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fStore, err := NewFileStore(storeDir, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanUp, err := createPGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresql_NewStore(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
if len(store.GetAllAccounts()) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId("account_id2", "testuser2", "")
|
||||
setupKey = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
SetupKey: "peerkeysetupkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies) != 1 {
|
||||
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||
return
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil {
|
||||
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByUser("testuser"); a == nil {
|
||||
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerID("testpeer"); a == nil {
|
||||
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil {
|
||||
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId("account_id", testUserID, "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
err = store.DeleteAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 0 {
|
||||
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
||||
}
|
||||
|
||||
_, err = store.GetAccountByPeerPubKey("peerkey")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key")
|
||||
|
||||
_, err = store.GetAccountByUser("testuser")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user")
|
||||
|
||||
_, err = store.GetAccountByPeerID("testpeer")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id")
|
||||
|
||||
_, err = store.GetAccountBySetupKey(setupKey.Key)
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key")
|
||||
|
||||
_, err = store.GetAccount(account.Id)
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
|
||||
// save new status of existing peer
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers["testpeer"].Status
|
||||
assert.Equal(t, newStatus.Connected, actual.Connected)
|
||||
}
|
||||
|
||||
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
}
|
||||
|
||||
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
@@ -1,16 +1,25 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
@@ -49,8 +58,11 @@ type Store interface {
|
||||
type StoreEngine string
|
||||
|
||||
const (
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
PostgresStoreEngine StoreEngine = "postgres"
|
||||
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
)
|
||||
|
||||
func getStoreEngineFromEnv() StoreEngine {
|
||||
@@ -61,8 +73,7 @@ func getStoreEngineFromEnv() StoreEngine {
|
||||
}
|
||||
|
||||
value := StoreEngine(strings.ToLower(kind))
|
||||
|
||||
if value == FileStoreEngine || value == SqliteStoreEngine {
|
||||
if value == FileStoreEngine || value == SqliteStoreEngine || value == PostgresStoreEngine {
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -94,18 +105,60 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S
|
||||
case SqliteStoreEngine:
|
||||
log.Info("using SQLite store engine")
|
||||
return NewSqliteStore(dataDir, metrics)
|
||||
case PostgresStoreEngine:
|
||||
log.Info("using Postgres store engine")
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
return NewPostgresqlStore(dsn, metrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported kind of store %s", kind)
|
||||
}
|
||||
}
|
||||
|
||||
// NewStoreFromJson is only used in tests
|
||||
func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
// migrate migrates the SQLite database to the latest schema
|
||||
func migrate(db *gorm.DB) error {
|
||||
migrations := getMigrations()
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrations() []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreFromJson is only used in tests
|
||||
func NewTestStoreFromJson(dataDir string) (Store, func(), error) {
|
||||
fstore, err := NewFileStore(dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {}
|
||||
|
||||
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
|
||||
kind := getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
@@ -115,10 +168,64 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro
|
||||
|
||||
switch kind {
|
||||
case FileStoreEngine:
|
||||
return fstore, nil
|
||||
return fstore, cleanUp, nil
|
||||
case SqliteStoreEngine:
|
||||
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
|
||||
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
case PostgresStoreEngine:
|
||||
cleanUp, err = createPGDB()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStoreFromFileStore(fstore, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
default:
|
||||
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
|
||||
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createPGDB() (func(), error) {
|
||||
ctx := context.Background()
|
||||
c, err := postgres.RunContainer(ctx,
|
||||
testcontainers.WithImage("postgres:alpine"),
|
||||
postgres.WithDatabase("test"),
|
||||
postgres.WithUsername("postgres"),
|
||||
postgres.WithPassword("postgres"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).WithStartupTimeout(15*time.Second)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
timeout := 10 * time.Second
|
||||
err = c.Stop(ctx, &timeout)
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
talksConn, err := c.ConnectionString(ctx)
|
||||
if err != nil {
|
||||
return cleanup, err
|
||||
}
|
||||
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -76,6 +77,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
@@ -97,6 +99,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
@@ -122,6 +125,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -140,6 +144,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -158,6 +163,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
|
||||
func TestUser_DeletePAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -190,6 +196,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
|
||||
func TestUser_GetPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -221,6 +228,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
|
||||
func TestUser_GetAllPATs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -322,6 +330,7 @@ func validateStruct(s interface{}) (err error) {
|
||||
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -359,6 +368,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -397,6 +407,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -421,6 +432,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
|
||||
func TestUser_InviteNewUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -549,6 +561,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -569,6 +582,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
@@ -650,6 +664,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -678,6 +693,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
@@ -790,6 +806,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
externalUser := &User{
|
||||
Id: "externalUser",
|
||||
@@ -853,6 +870,7 @@ func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
@@ -880,6 +898,8 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
|
||||
Reference in New Issue
Block a user