diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 5998fab01..48ffd27e9 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -14,8 +14,20 @@ jobs: test: strategy: matrix: - store: ['jsonfile', 'sqlite'] + store: ['jsonfile', 'sqlite', 'postgresql'] runs-on: macos-latest + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: integrations + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + ports: + - 5432:5432 steps: - name: Install Go uses: actions/setup-go@v4 diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 8015fb36a..61ce3e587 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -15,8 +15,20 @@ jobs: strategy: matrix: arch: ['386','amd64'] - store: ['jsonfile', 'sqlite'] + store: ['jsonfile', 'sqlite', 'postgresql'] runs-on: ubuntu-latest + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: integrations + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + ports: + - 5432:5432 steps: - name: Install Go uses: actions/setup-go@v4 diff --git a/go.mod b/go.mod index 1c0cfc0c0..fa3518743 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/lib/pq v1.10.9 github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 github.com/mattn/go-sqlite3 v1.14.17 @@ -77,8 +78,9 @@ require ( golang.org/x/term v0.13.0 google.golang.org/api v0.126.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/postgres v1.5.4 gorm.io/driver/sqlite v1.5.3 - gorm.io/gorm v1.25.4 + gorm.io/gorm v1.25.5 ) require ( @@ -116,6 +118,9 @@ require ( github.com/googleapis/gax-go/v2 v2.10.0 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/josharian/native v1.0.0 // indirect @@ -151,7 +156,6 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect diff --git a/go.sum b/go.sum index 26322b1ad..2620967d6 100644 --- a/go.sum +++ b/go.sum @@ -382,6 +382,12 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= @@ -434,6 +440,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= @@ -1179,7 +1187,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE= @@ -1209,10 +1216,12 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= -gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= -gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= diff --git a/management/server/postgresql_store.go b/management/server/postgresql_store.go new file mode 100644 index 000000000..8e22c4bf5 --- /dev/null +++ b/management/server/postgresql_store.go @@ -0,0 +1,534 @@ +package server + +import ( + "bytes" + "context" + "database/sql" + "encoding/base64" + "encoding/gob" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "time" + + _ "github.com/lib/pq" + log "github.com/sirupsen/logrus" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +// PostgresqlStore represents an account storage backed by a Postgres DB persisted to disk +type PostgresqlStore struct { + db *gorm.DB + dsn string + accountLocks sync.Map + globalAccountLock sync.Mutex + metrics telemetry.AppMetrics + installationPK int +} + +// GobSerializer gob serializer +type GobBase64Serializer struct{} + +// Scan implements serializer interface with base64 encoding +func (GobBase64Serializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + case string: + bytesValue = []byte(v) + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + if len(bytesValue) > 0 { + var decoded []byte + decoded, err = base64.StdEncoding.DecodeString(string(bytesValue)) + if err == nil { + decoder := gob.NewDecoder(bytes.NewBuffer(decoded)) + err = decoder.Decode(fieldValue.Interface()) + } + } + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobBase64Serializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return base64.StdEncoding.EncodeToString(buf.Bytes()), err +} + +// NewPostgresqlStore restores a store from the file located in the datadir +func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*PostgresqlStore, error) { + schema.RegisterSerializer("gob", GobBase64Serializer{}) + + sqlDB, err := sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + + db, err := gorm.Open(postgres.New(postgres.Config{ + Conn: sqlDB, + }), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + PrepareStmt: true, + }) + if err != nil { + return nil, err + } + + //sql, err := db.DB() + //if err != nil { + // return nil, err + //} + conns := runtime.NumCPU() + sqlDB.SetMaxOpenConns(conns) // TODO: make it configurable + + err = db.AutoMigrate( + &Account{}, &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, + &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &installation{}, &account.ExtraSettings{}, + ) + if err != nil { + return nil, err + } + + return &PostgresqlStore{db: db, dsn: dsn, metrics: metrics, installationPK: 1}, nil +} + +// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores PostgreSQL DB in the file located in datadir +func NewPostgresqlStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*PostgresqlStore, error) { + store, err := NewPostgresqlStore(dataDir, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(filestore.InstallationID) + if err != nil { + return nil, err + } + + for _, account := range filestore.GetAllAccounts() { + err := store.SaveAccount(account) + if err != nil { + return nil, err + } + } + + return store, nil +} + +// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock +func (s *PostgresqlStore) AcquireGlobalLock() (unlock func()) { + log.Debugf("acquiring global lock") + start := time.Now() + s.globalAccountLock.Lock() + + unlock = func() { + s.globalAccountLock.Unlock() + log.Debugf("released global lock in %v", time.Since(start)) + } + + took := time.Since(start) + log.Debugf("took %v to acquire global lock", took) + if s.metrics != nil { + s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) + } + + return unlock +} + +func (s *PostgresqlStore) AcquireAccountLock(accountID string) (unlock func()) { + log.Debugf("acquiring lock for account %s", accountID) + + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + unlock = func() { + mtx.Unlock() + log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + +func (s *PostgresqlStore) SaveAccount(account *Account) error { + start := time.Now() + + for _, key := range account.SetupKeys { + account.SetupKeysG = append(account.SetupKeysG, *key) + } + + for id, peer := range account.Peers { + peer.ID = id + account.PeersG = append(account.PeersG, *peer) + } + + for id, user := range account.Users { + user.Id = id + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + account.UsersG = append(account.UsersG, *user) + } + + for id, group := range account.Groups { + group.ID = id + account.GroupsG = append(account.GroupsG, *group) + } + + for id, rule := range account.Rules { + rule.ID = id + account.RulesG = append(account.RulesG, *rule) + } + + for id, route := range account.Routes { + route.ID = id + account.RoutesG = append(account.RoutesG, *route) + } + + for id, ns := range account.NameServerGroups { + ns.ID = id + account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) + } + + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account) + if result.Error != nil { + return result.Error + } + + result = tx. + Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + if result.Error != nil { + return result.Error + } + return nil + }) + + took := time.Since(start) + if s.metrics != nil { + s.metrics.StoreMetrics().CountPersistenceDuration(took) + } + log.Debugf("took %d ms to persist an account to the PostgreSQL", took.Milliseconds()) + + return err +} + +func (s *PostgresqlStore) DeleteAccount(account *Account) error { + start := time.Now() + + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account) + if result.Error != nil { + return result.Error + } + + return nil + }) + + took := time.Since(start) + if s.metrics != nil { + s.metrics.StoreMetrics().CountPersistenceDuration(took) + } + log.Debugf("took %d ms to delete an account to the PostgreSQL", took.Milliseconds()) + + return err +} + +func (s *PostgresqlStore) SaveInstallationID(ID string) error { + installation := installation{InstallationIDValue: ID} + installation.ID = uint(s.installationPK) + + return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error +} + +func (s *PostgresqlStore) GetInstallationID() string { + var installation installation + + if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil { + return "" + } + + return installation.InstallationIDValue +} + +func (s *PostgresqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { + var peer nbpeer.Peer + + result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) + if result.Error != nil { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } + + peer.Status = &peerStatus + + return s.db.Save(peer).Error +} + +// DeleteHashedPAT2TokenIDIndex is noop in PostgreSQL +func (s *PostgresqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { + return nil +} + +// DeleteTokenID2UserIDIndex is noop in PostgreSQL +func (s *PostgresqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { + return nil +} + +func (s *PostgresqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) { + var account Account + + result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + + // TODO: rework to not call GetAccount + return s.GetAccount(account.Id) +} + +func (s *PostgresqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) { + var key SetupKey + result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if key.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(key.AccountID) +} + +func (s *PostgresqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) { + var token PersonalAccessToken + result := s.db.First(&token, "hashed_token = ?", hashedToken) + if result.Error != nil { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return token.ID, nil +} + +func (s *PostgresqlStore) GetUserByTokenID(tokenID string) (*User, error) { + var token PersonalAccessToken + result := s.db.First(&token, "id = ?", tokenID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if token.UserID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + var user User + result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + + return &user, nil +} + +func (s *PostgresqlStore) GetAllAccounts() (all []*Account) { + var accounts []Account + result := s.db.Find(&accounts) + if result.Error != nil { + return all + } + + for _, account := range accounts { + if acc, err := s.GetAccount(account.Id); err == nil { + all = append(all, acc) + } + } + + return all +} + +func (s *PostgresqlStore) GetAccount(accountID string) (*Account, error) { + var account Account + + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload(clause.Associations). + First(&account, "id = ?", accountID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found") + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*PolicyRule + err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + account.Rules = make(map[string]*Rule, len(account.RulesG)) + for _, rule := range account.RulesG { + account.Rules[rule.ID] = rule.Copy() + } + account.RulesG = nil + + account.Routes = make(map[string]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *PostgresqlStore) GetAccountByUser(userID string) (*Account, error) { + var user User + result := s.db.Select("account_id").First(&user, "id = ?", userID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if user.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(user.AccountID) +} + +func (s *PostgresqlStore) GetAccountByPeerID(peerID string) (*Account, error) { + var peer nbpeer.Peer + result := s.db.Select("account_id").First(&peer, "id = ?", peerID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +func (s *PostgresqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { + var peer nbpeer.Peer + + result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +// SaveUserLastLogin stores the last login time for a user in DB. +func (s *PostgresqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { + var user User + + result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) + if result.Error != nil { + return status.Errorf(status.NotFound, "user %s not found", userID) + } + + user.LastLogin = lastLogin + + return s.db.Save(user).Error +} + +// Close is noop in PostgreSQL +func (s *PostgresqlStore) Close() error { + return nil +} + +// GetStoreEngine returns PostgresqlStoreEngine +func (s *PostgresqlStore) GetStoreEngine() StoreEngine { + return PostgresqlStoreEngine +} diff --git a/management/server/postgresql_store_test.go b/management/server/postgresql_store_test.go new file mode 100644 index 000000000..65083c7b8 --- /dev/null +++ b/management/server/postgresql_store_test.go @@ -0,0 +1,353 @@ +package server + +import ( + "fmt" + "math/rand" + "net" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/postgres" + "gorm.io/gorm" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/util" +) + +func TestPostgresql_NewStore(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStore(t) + defer cleanup() + + if len(store.GetAllAccounts()) != 0 { + t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + } +} + +func TestPostgresql_SaveAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStore(t) + defer cleanup() + + account := newAccountWithId("account_id", "testuser", "") + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["testpeer"] = &nbpeer.Peer{ + Key: "peerkey", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + err := store.SaveAccount(account) + require.NoError(t, err) + + account2 := newAccountWithId("account_id2", "testuser2", "") + setupKey = GenerateDefaultSetupKey() + account2.SetupKeys[setupKey.Key] = setupKey + account2.Peers["testpeer2"] = &nbpeer.Peer{ + Key: "peerkey2", + SetupKey: "peerkeysetupkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + err = store.SaveAccount(account2) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 2 { + t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") + } + + a, err := store.GetAccount(account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } + + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } + + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } + + if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountByUser("testuser"); a == nil { + t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) + } + + if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) + } +} + +func TestPostgresql_DeleteAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStore(t) + defer cleanup() + + testUserID := "testuser" + user := NewAdminUser(testUserID) + user.PATs = map[string]*PersonalAccessToken{"testtoken": { + ID: "testtoken", + Name: "test token", + }} + + account := newAccountWithId("account_id", testUserID, "") + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["testpeer"] = &nbpeer.Peer{ + Key: "peerkey", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + account.Users[testUserID] = user + + err := store.SaveAccount(account) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 1 { + t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") + } + + err = store.DeleteAccount(account) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 0 { + t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") + } + + _, err = store.GetAccountByPeerPubKey("peerkey") + require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key") + + _, err = store.GetAccountByUser("testuser") + require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user") + + _, err = store.GetAccountByPeerID("testpeer") + require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id") + + _, err = store.GetAccountBySetupKey(setupKey.Key) + require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key") + + _, err = store.GetAccount(account.Id) + require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") + + for _, policy := range account.Policies { + var rules []*PolicyRule + err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") + require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") + + } + + for _, accountUser := range account.Users { + var pats []*PersonalAccessToken + err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") + require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") + + } + +} + +func TestPostgresql_SavePeerStatus(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json") + defer cleanup() + + account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + require.NoError(t, err) + + // save status of non-existing peer + newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} + err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + assert.Error(t, err) + + // save new status of existing peer + account.Peers["testpeer"] = &nbpeer.Peer{ + Key: "peerkey", + ID: "testpeer", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + } + + err = store.SaveAccount(account) + require.NoError(t, err) + + err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + require.NoError(t, err) + + account, err = store.GetAccount(account.Id) + require.NoError(t, err) + + actual := account.Peers["testpeer"].Status + assert.Equal(t, newStatus.Connected, actual.Connected) + assert.True(t, newStatus.LastSeen.Equal(actual.LastSeen)) +} + +func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json") + defer cleanup() + + existingDomain := "test.com" + + account, err := store.GetAccountByPrivateDomain(existingDomain) + require.NoError(t, err, "should found account") + require.Equal(t, existingDomain, account.Domain, "domains should match") + + _, err = store.GetAccountByPrivateDomain("missing-domain.com") + require.Error(t, err, "should return error on domain lookup") +} + +func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json") + defer cleanup() + + hashed := "SoMeHaShEdToKeN" + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + token, err := store.GetTokenIDByHashedToken(hashed) + require.NoError(t, err) + require.Equal(t, id, token) +} + +func TestPostgresql_GetUserByTokenID(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The PostgreSQL store is not properly supported by Windows yet") + } + + store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json") + defer cleanup() + + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + + user, err := store.GetUserByTokenID(id) + require.NoError(t, err) + require.Equal(t, id, user.PATs[id].ID) +} + +func newPostgresqlStore(t *testing.T) (*PostgresqlStore, func()) { + t.Helper() + + dbName := "store_" + randString(10) + postgresDsn := "host=localhost user=postgres password=postgres port=5432 sslmode=disable" + db, _ := gorm.Open(postgres.Open(postgresDsn), &gorm.Config{}) + result := db.Exec(fmt.Sprintf("CREATE DATABASE %s ENCODING = 'UTF8'", dbName)) + if result.Error != nil { + t.Fatalf("could not initialize postgresql store: %s", result.Error) + } + postgresDsn = fmt.Sprintf("%s dbname=%s ", postgresDsn, dbName) + cleanup := func() { + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)) + } + + store, err := NewPostgresqlStore(postgresDsn, nil) + if err != nil { + t.Fatalf("could not initialize postgresql store: %s", err) + } + require.NoError(t, err) + require.NotNil(t, store) + + return store, cleanup +} + +func randString(n int) string { + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyz1234567890") + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func newPostgresqlStoreFromFile(t *testing.T, filename string) (*PostgresqlStore, func()) { + t.Helper() + + storeDir := t.TempDir() + + err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) + require.NoError(t, err) + + fStore, err := NewFileStore(storeDir, nil) + require.NoError(t, err) + + dbName := "store_" + randString(10) + postgresDsn := "host=localhost user=postgres password=postgres port=5432 sslmode=disable" + db, _ := gorm.Open(postgres.Open(postgresDsn), &gorm.Config{}) + result := db.Exec(fmt.Sprintf("CREATE DATABASE %s ENCODING = 'UTF8'", dbName)) + if result.Error != nil { + t.Fatalf("could not initialize postgresql store: %s", result.Error) + } + postgresDsn = fmt.Sprintf("%s dbname=%s ", postgresDsn, dbName) + cleanup := func() { + db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)) + } + + store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil) + require.NoError(t, err) + require.NotNil(t, store) + + return store, cleanup +} + +/* +func newAccount(store Store, id int) error { + str := fmt.Sprintf("%s-%d", uuid.New().String(), id) + account := newAccountWithId(str, str+"-testuser", "example.com") + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["p"+str] = &nbpeer.Peer{ + Key: "peerkey" + str, + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } + + return store.SaveAccount(account) +} +*/ diff --git a/management/server/store.go b/management/server/store.go index a482ca947..55de10c4f 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -44,8 +44,9 @@ type Store interface { type StoreEngine string const ( - FileStoreEngine StoreEngine = "jsonfile" - SqliteStoreEngine StoreEngine = "sqlite" + FileStoreEngine StoreEngine = "jsonfile" + SqliteStoreEngine StoreEngine = "sqlite" + PostgresqlStoreEngine StoreEngine = "postgresql" ) func getStoreEngineFromEnv() StoreEngine { @@ -76,6 +77,9 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S case SqliteStoreEngine: log.Info("using SQLite store engine") return NewSqliteStore(dataDir, metrics) + case PostgresqlStoreEngine: + log.Info("using PostgreSQL store engine") + return NewPostgresqlStore(dataDir, metrics) // dataDir is dsn default: return nil, fmt.Errorf("unsupported kind of store %s", kind) } @@ -94,6 +98,8 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro return fstore, nil case SqliteStoreEngine: return NewSqliteStoreFromFileStore(fstore, dataDir, metrics) + case PostgresqlStoreEngine: + return NewPostgresqlStoreFromFileStore(fstore, dataDir, metrics) // dataDir is dsn default: return nil, fmt.Errorf("unsupported store engine %s", kind) }