From 6c50b0c84b66cfac198e0aed47bfc9688237fb7f Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:47:03 +0200 Subject: [PATCH] [management] Add transaction to addPeer (#2469) This PR removes the GetAccount and SaveAccount operations from the AddPeer and instead makes use of gorm.Transaction to add the new peer. --- .github/workflows/golang-test-linux.yml | 2 +- management/server/account.go | 42 ++- management/server/ephemeral_test.go | 3 +- management/server/file_store.go | 168 +++++++++- management/server/management_proto_test.go | 47 ++- management/server/peer.go | 303 ++++++++++-------- management/server/peer_test.go | 185 +++++++++++ management/server/sql_store.go | 238 ++++++++++++-- management/server/sql_store_test.go | 160 +++++++++ management/server/status/error.go | 10 + management/server/store.go | 27 +- .../server/testdata/extended-store.json | 120 +++++++ management/server/user.go | 6 +- 13 files changed, 1095 insertions(+), 216 deletions(-) create mode 100644 management/server/testdata/extended-store.json diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 263623bd1..2d5cf2856 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/management/server/account.go b/management/server/account.go index 7159aa9ac..208315643 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -263,6 +263,11 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } @@ -700,14 +705,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getUserGroups(userID string) ([]string, error) { - user, err := a.FindUser(userID) - if err != nil { - return nil, err - } - return user.AutoGroups, nil -} - func (a *Account) getPeerDNSManagementStatus(peerID string) bool { peerGroups := a.getPeerGroups(peerID) enabled := true @@ -734,14 +731,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap { return groupList } -func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) { - key, err := a.FindSetupKey(setupKey) - if err != nil { - return nil, err - } - return key.AutoGroups, nil -} - func (a *Account) getTakenIPs() []net.IP { var takenIps []net.IP for _, existingPeer := range a.Peers { @@ -2082,7 +2071,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee } func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, peer.UserID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2103,6 +2092,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee return false, nil } +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) + if err != nil { + return "", fmt.Errorf("failed to get peer dns labels: %w", err) + } + + labelMap := ConvertSliceToMap(existingLabels) + newLabel, err := getPeerHostLabel(peerHostName, labelMap) + if err != nil { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + if newLabel == "" { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + return newLabel, nil +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 36c88f1d1..1390352a5 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,6 +7,7 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) type MockStore struct { @@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou return s.account, nil } - return nil, fmt.Errorf("account not found") + return nil, status.NewPeerNotFoundError(peerId) } type MocAccountManager struct { diff --git a/management/server/file_store.go b/management/server/file_store.go index 1927568ef..95d5b4e6e 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,6 +2,8 @@ package server import ( "context" + "errors" + "net" "os" "path/filepath" "strings" @@ -46,6 +48,158 @@ type FileStore struct { metrics telemetry.AppMetrics `json:"-"` } +func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { + return f(s) +} + +func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] + if !ok { + return status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + account.SetupKeys[setupKeyID].UsedTimes++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + allGroup, err := account.GetGroupAll() + if err != nil || allGroup == nil { + return errors.New("all group not found") + } + + allGroup.Peers = append(allGroup.Peers, peerID) + + return nil +} + +func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountId) + if err != nil { + return err + } + + account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) + + return nil +} + +func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[peer.AccountID] + if !ok { + return status.NewAccountNotFoundError(peer.AccountID) + } + + account.Peers[peer.ID] = peer + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[accountId] + if !ok { + return status.NewAccountNotFoundError(accountId) + } + + account.Network.Serial++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] + if !ok { + return nil, status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + setupKey, ok := account.SetupKeys[key] + if !ok { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return setupKey, nil +} + +func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + var takenIps []net.IP + for _, existingPeer := range account.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps, nil +} + +func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + existingLabels := []string{} + for _, peer := range account.Peers { + if peer.DNSLabel != "" { + existingLabels = append(existingLabels, peer.DNSLabel) + } + } + return existingLabels, nil +} + +func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Network, nil +} + type StoredAccount struct{} // NewFileStore restores a store from the file located in the datadir @@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return nil, status.NewSetupKeyNotFoundError() } account, err := s.getAccount(accountID) @@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, return account.Users[userID].Copy(), nil } -func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) { +func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { accountID, ok := s.UserID2AccountID[userID] if !ok { return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") @@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { func (s *FileStore) getAccount(accountID string) (*Account, error) { account, ok := s.Accounts[accountID] if !ok { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return account, nil @@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) ( accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return "", status.NewSetupKeyNotFoundError() } return accountID, nil } -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp return nil, status.NewPeerNotFoundError(peerKey) } -func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { +func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { s.mux.Lock() defer s.mux.Unlock() @@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer. } // SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { s.mux.Lock() defer s.mux.Unlock() diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 00ee4bda2..ff09129bd 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { - if os.Getenv("CI") == "true" { - t.Skip("Skipping on CI") + if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { + t.Skip("Skipping test on CI or Windows") } t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") @@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) { // {"M", 250, 1}, // {"L", 500, 1}, // {"XL", 750, 1}, - {"XXL", 2000, 1}, + {"XXL", 5000, 1}, } log.SetOutput(io.Discard) @@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) { } defer mgmtServer.GracefulStop() + t.Logf("management setup complete, start registering peers") + var counter int32 var counterStart int32 - var wg sync.WaitGroup + var wgAccount sync.WaitGroup var mu sync.Mutex messageCalls := []func() error{} for j := 0; j < bc.accounts; j++ { - wg.Add(1) + wgAccount.Add(1) + var wgPeer sync.WaitGroup go func(j int, counter *int32, counterStart *int32) { - defer wg.Done() + defer wgAccount.Done() account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j)) if err != nil { @@ -722,7 +725,9 @@ func Test_LoginPerformance(t *testing.T) { return } + startTime := time.Now() for i := 0; i < bc.peers; i++ { + wgPeer.Add(1) key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Logf("failed to generate key: %v", err) @@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) { mu.Lock() messageCalls = append(messageCalls, login) mu.Unlock() - _, _, _, err = am.LoginPeer(context.Background(), peerLogin) - if err != nil { - t.Logf("failed to login peer: %v", err) - return - } - atomic.AddInt32(counterStart, 1) - if *counterStart%100 == 0 { - t.Logf("registered %d peers", *counterStart) - } + go func(peerLogin PeerLogin, counterStart *int32) { + defer wgPeer.Done() + _, _, _, err = am.LoginPeer(context.Background(), peerLogin) + if err != nil { + t.Logf("failed to login peer: %v", err) + return + } + + atomic.AddInt32(counterStart, 1) + if *counterStart%100 == 0 { + t.Logf("registered %d peers", *counterStart) + } + }(peerLogin, counterStart) + } + wgPeer.Wait() + + t.Logf("Time for registration: %s", time.Since(startTime)) }(j, &counter, &counterStart) } - wg.Wait() + wgAccount.Wait() t.Logf("prepared %d login calls", len(messageCalls)) testLoginPerformance(t, messageCalls) diff --git a/management/server/peer.go b/management/server/peer.go index 26e27617d..da9586734 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,6 +11,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/proto" @@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } }() - var account *Account - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { - if am.idpManager != nil { - userdata, err := am.lookupUserInCache(ctx, userID, account) - if err == nil && userdata != nil { - peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) - } - } - } - // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = account.FindPeerByPubKey(peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ Timestamp: time.Now().UTC(), - AccountID: account.Id, + AccountID: accountID, } - var ephemeral bool - setupKeyName := "" - if !addedByUser { - // validate the setup key if adding with a key - sk, err := account.FindSetupKey(upperKey) - if err != nil { - return nil, nil, nil, err - } + var newPeer *nbpeer.Peer - if !sk.IsValid() { - return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") - } - - account.SetupKeys[sk.Key] = sk.IncrementUsage() - opEvent.InitiatorID = sk.Id - opEvent.Activity = activity.PeerAddedWithSetupKey - ephemeral = sk.Ephemeral - setupKeyName = sk.Name - } else { - opEvent.InitiatorID = userID - opEvent.Activity = activity.PeerAddedByUser - } - - takenIps := account.getTakenIPs() - existingLabels := account.getPeerDNSLabels() - - newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels) - if err != nil { - return nil, nil, nil, err - } - - peer.DNSLabel = newLabel - network := account.Network - nextIp, err := AllocatePeerIP(network.Net, takenIps) - if err != nil { - return nil, nil, nil, err - } - - registrationTime := time.Now().UTC() - - newPeer := &nbpeer.Peer{ - ID: xid.New().String(), - Key: peer.Key, - SetupKey: upperKey, - IP: nextIp, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: newLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, - } - - if am.geo != nil && newPeer.Location.ConnectionIP != nil { - location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + var groupsToAdd []string + var setupKeyID string + var setupKeyName string + var ephemeral bool + if addedByUser { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + if err != nil { + return fmt.Errorf("failed to get user groups: %w", err) + } + groupsToAdd = user.AutoGroups + opEvent.InitiatorID = userID + opEvent.Activity = activity.PeerAddedByUser } else { - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID - } - } + // Validate the setup key + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } - // add peer to 'All' group - group, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, err - } - group.Peers = append(group.Peers, newPeer.ID) + if !sk.IsValid() { + return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + } - var groupsToAdd []string - if addedByUser { - groupsToAdd, err = account.getUserGroups(userID) - if err != nil { - return nil, nil, nil, err + opEvent.InitiatorID = sk.Id + opEvent.Activity = activity.PeerAddedWithSetupKey + groupsToAdd = sk.AutoGroups + ephemeral = sk.Ephemeral + setupKeyID = sk.Id + setupKeyName = sk.Name } - } else { - groupsToAdd, err = account.getSetupKeyGroups(upperKey) - if err != nil { - return nil, nil, nil, err - } - } - if len(groupsToAdd) > 0 { - for _, s := range groupsToAdd { - if g, ok := account.Groups[s]; ok && g.Name != "All" { - g.Peers = append(g.Peers, newPeer.ID) + if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { + if am.idpManager != nil { + userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) + if err == nil && userdata != nil { + peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) + } } } - } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) - - if addedByUser { - user, err := account.FindUser(userID) + freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname) if err != nil { - return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") + return fmt.Errorf("failed to get free DNS label: %w", err) } - user.updateLastLogin(newPeer.LastLogin) - } - account.Peers[newPeer.ID] = newPeer - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + freeIP, err := am.getFreeIP(ctx, transaction, accountID) + if err != nil { + return fmt.Errorf("failed to get free IP: %w", err) + } + + registrationTime := time.Now().UTC() + newPeer = &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + } + opEvent.TargetID = newPeer.ID + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + if !addedByUser { + opEvent.Meta["setup_key_name"] = setupKeyName + } + + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + + err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + + if len(groupsToAdd) > 0 { + for _, g := range groupsToAdd { + err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + if err != nil { + return err + } + } + } + + err = transaction.AddPeerToAccount(ctx, newPeer) + if err != nil { + return fmt.Errorf("failed to add peer to account: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if addedByUser { + err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin) + if err != nil { + return fmt.Errorf("failed to update user last login: %w", err) + } + } else { + err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) + if err != nil { + return fmt.Errorf("failed to increment setup key usage: %w", err) + } + } + + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) + return nil + }) + if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } - // Account is saved, we can release the lock - unlock() - unlock = nil - - opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) - if !addedByUser { - opEvent.Meta["setup_key_name"] = setupKeyName + if newPeer == nil { + return nil, nil, nil, fmt.Errorf("new peer is nil") } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + unlock() + unlock = nil + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) approvedPeersMap, err := am.GetValidatedPeers(account) @@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, peer) + postureChecks := am.getPeerPostureChecks(account, newPeer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { + takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get taken IPs: %w", err) + } + + network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return nil, fmt.Errorf("failed getting network: %w", err) + } + + nextIp, err := AllocatePeerIP(network.Net, takenIps) + if err != nil { + return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) + } + + return nextIp, nil +} + // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) @@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - settings, err := am.Store.GetAccountSettings(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us return err } - err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) if err != nil { return err } @@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } + +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} + } + return labelMap +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 448e83a08..4b2ec66c6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -7,20 +7,24 @@ import ( "net" "net/netip" "os" + "runtime" "testing" "time" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/telemetry" nbroute "github.com/netbirdio/netbird/route" ) @@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) } + +func Test_RegisterPeerByUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + UserID: existingUserID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + LastLogin: time.Now(), + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + require.NoError(t, err) + + peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.UserID, existingUserID) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) + + assert.Equal(t, uint64(1), account.Network.Serial) + + lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin) +} + +func Test_RegisterPeerBySetupKey(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer) + + require.NoError(t, err) + + peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.SetupKey, existingSetupKeyID) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) + + assert.Equal(t, uint64(1), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) + assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) + +} + +func Test_RegisterPeerRollbackOnFailure(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + + _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + require.Error(t, err) + + _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + require.Error(t, err) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.NotContains(t, account.Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID) + + assert.Equal(t, uint64(0), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 0fb3d391f..6f1f66ef8 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "path/filepath" "runtime" @@ -33,6 +34,7 @@ import ( const ( storeSqliteFileName = "store.db" idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" peerNotFoundFMT = "peer %s not found" ) @@ -415,13 +417,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey - result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting setup key from store") + return nil, status.NewSetupKeyNotFoundError() } if key.AccountID == "" { @@ -474,15 +475,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) { +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User - result := s.db.First(&user, idQueryCondition, userID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "user not found: index lookup failed") + return nil, status.NewUserNotFoundError(userID) } - log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting user from store") + return nil, status.NewGetUserFromStoreError() } return &user, nil @@ -535,7 +536,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -595,7 +596,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { var user User - result := s.db.Select("account_id").First(&user, idQueryCondition, userID) + result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -612,12 +613,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -631,12 +631,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -650,12 +649,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) + result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } @@ -677,61 +675,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { - var key SetupKey var accountID string - result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting setup key from store") + return "", status.NewSetupKeyNotFoundError() + } + + if accountID == "" { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } return accountID, nil } -func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + var ipJSONStrings []string + + // Fetch the IP addresses as JSON strings + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Pluck("ip", &ipJSONStrings) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + return nil, status.Errorf(status.Internal, "issue getting IPs from store") + } + + // Convert the JSON strings to net.IP objects + ips := make([]net.IP, len(ipJSONStrings)) + for i, ipJSON := range ipJSONStrings { + var ip net.IP + if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { + return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") + } + ips[i] = ip + } + + return ips, nil +} + +func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + var labels []string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Pluck("dns_label", &labels) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + } + + return labels, nil +} + +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + var accountNetwork AccountNetwork + + if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.Errorf(status.Internal, "issue getting network from store") + } + return accountNetwork.Network, nil +} + +func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer - result := s.db.First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store") } return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { var accountSettings AccountSettings - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store") } return accountSettings.Settings, nil } // SaveUserLastLogin stores the last login time for a user in DB. -func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User - result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "user %s not found", userID) + return status.NewUserNotFoundError(userID) } - return status.Errorf(status.Internal, "issue getting user from store") + return status.NewGetUserFromStoreError() } - user.LastLogin = lastLogin - return s.db.Save(user).Error + return s.db.Save(&user).Error } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -850,3 +904,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } + +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + var setupKey SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + return nil, status.NewSetupKeyNotFoundError() + } + return &setupKey, nil +} + +func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + result := s.db.WithContext(ctx).Model(&SetupKey{}). + Where(idQueryCondition, setupKeyID). + Updates(map[string]interface{}{ + "used_times": gorm.Expr("used_times + 1"), + "last_used": time.Now(), + }) + + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil +} + +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group 'All' not found for account") + } + return status.Errorf(status.Internal, "issue finding group 'All'") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerID { + return nil + } + } + + group.Peers = append(group.Peers, peerID) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group 'All'") + } + + return nil +} + +func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group not found for account") + } + return status.Errorf(status.Internal, "issue finding group") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerId { + return nil + } + } + + group.Peers = append(group.Peers, peerId) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group") + } + + return nil +} + +func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + return status.Errorf(status.Internal, "issue adding peer to account") + } + + return nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing network serial count") + } + return nil +} + +func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + tx := s.db.WithContext(ctx).Begin() + if tx.Error != nil { + return tx.Error + } + repo := s.withTx(tx) + err := operation(repo) + if err != nil { + tx.Rollback() + return err + } + return tx.Commit().Error +} + +func (s *SqlStore) withTx(tx *gorm.DB) Store { + return &SqlStore{ + db: tx, + } +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index ce4ee531a..64ef36831 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } + +func TestSqlite_GetTakenIPs(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []net.IP{}, takenIPs) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip1 := net.IP{1, 1, 1, 1}.To16() + assert.Equal(t, []net.IP{ip1}, takenIPs) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + IP: net.IP{2, 2, 2, 2}, + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip2 := net.IP{2, 2, 2, 2}.To16() + assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) + +} + +func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test"}, labels) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + DNSLabel: "peer2.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) +} + +func TestSqlite_GetAccountNetwork(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip := net.IP{100, 64, 0, 0}.To16() + assert.Equal(t, ip, network.Net.IP) + assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask) + assert.Equal(t, "", network.Dns) + assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier) + assert.Equal(t, uint64(0), network.Serial) +} + +func TestSqlite_GetSetupKeyBySecret(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) + assert.Equal(t, "Default key", setupKey.Name) +} + +func TestSqlite_incrementSetupKeyUsage(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 0, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 1, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 2, setupKey.UsedTimes) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 58b9a84a0..d7fde35b9 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error { func NewPeerLoginExpiredError() error { return Errorf(PermissionDenied, "peer login has expired, please log in once more") } + +// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key +func NewSetupKeyNotFoundError() error { + return Errorf(NotFound, "setup key not found") +} + +// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store +func NewGetUserFromStoreError() error { + return Errorf(Internal, "issue getting user from store") +} diff --git a/management/server/store.go b/management/server/store.go index a2b489391..84b3b140c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -27,6 +27,15 @@ import ( "github.com/netbirdio/netbird/route" ) +type LockingStrength string + +const ( + LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes. + LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions. + LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows. + LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates. +) + type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) @@ -41,7 +50,7 @@ type Store interface { GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetUserByUserID(ctx context.Context, userID string) (*User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) SaveAccount(ctx context.Context, account *Account) error @@ -60,14 +69,24 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + ExecuteInTransaction(ctx context.Context, f func(store Store) error) error } type StoreEngine string diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json new file mode 100644 index 000000000..7f96e57a8 --- /dev/null +++ b/management/server/testdata/extended-store.json @@ -0,0 +1,120 @@ +{ + "Accounts": { + "bf1c8084-ba50-4ce7-9439-34653001fc3b": { + "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", + "CreatedBy": "", + "Domain": "test.com", + "DomainCategory": "private", + "IsDomainPrimaryAccount": true, + "SetupKeys": { + "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "Name": "Default key", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["cfefqs706sqkneg59g2g"], + "UsageLimit": 0, + "Ephemeral": false + }, + "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "Name": "Faulty key with non existing group", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["abcd"], + "UsageLimit": 0, + "Ephemeral": false + } + }, + "Network": { + "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "Net": { + "IP": "100.64.0.0", + "Mask": "//8AAA==" + }, + "Dns": "", + "Serial": 0 + }, + "Peers": {}, + "Users": { + "edafee4e-63fb-11ec-90d6-0242ac120003": { + "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "admin", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": ["cfefqs706sqkneg59g3g"], + "PATs": {}, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + }, + "f4f6d672-63fb-11ec-90d6-0242ac120003": { + "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "user", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": null, + "PATs": { + "9dj38s35-63fb-11ec-90d6-0242ac120003": { + "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", + "UserID": "", + "Name": "", + "HashedToken": "SoMeHaShEdToKeN", + "ExpirationDate": "2023-02-27T00:00:00Z", + "CreatedBy": "user", + "CreatedAt": "2023-01-01T00:00:00Z", + "LastUsed": "2023-02-01T00:00:00Z" + } + }, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + } + }, + "Groups": { + "cfefqs706sqkneg59g4g": { + "ID": "cfefqs706sqkneg59g4g", + "Name": "All", + "Peers": [] + }, + "cfefqs706sqkneg59g3g": { + "ID": "cfefqs706sqkneg59g3g", + "Name": "AwesomeGroup1", + "Peers": [] + }, + "cfefqs706sqkneg59g2g": { + "ID": "cfefqs706sqkneg59g2g", + "Name": "AwesomeGroup2", + "Peers": [] + } + }, + "Rules": null, + "Policies": [], + "Routes": null, + "NameServerGroups": null, + "DNSSettings": null, + "Settings": { + "PeerLoginExpirationEnabled": false, + "PeerLoginExpiration": 86400000000000, + "GroupsPropagationEnabled": false, + "JWTGroupsEnabled": false, + "JWTGroupsClaimName": "" + } + } + }, + "InstallationID": "" +} diff --git a/management/server/user.go b/management/server/user.go index 727bc5c6b..9e60bb94b 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() } -func (u *User) updateLastLogin(login time.Time) { - u.LastLogin = login -} - // HasAdminPower returns true if the user has admin or owner roles, false otherwise func (u *User) HasAdminPower() bool { return u.Role == UserRoleAdmin || u.Role == UserRoleOwner @@ -386,7 +382,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) }