diff --git a/management/server/account.go b/management/server/account.go index 6481a71d8..ebc89983e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -128,7 +128,7 @@ type AccountManager interface { GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) @@ -1048,7 +1048,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + } + halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1058,53 +1067,57 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + oldSettings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { + if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil { return nil, err } - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { + return nil, fmt.Errorf("failed updating account settings: %w", err) } - err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) - if err != nil { - return nil, err - } - - oldSettings := account.Settings if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - updatedAccount := account.UpdateSettings(newSettings) - - err = am.Store.SaveAccount(ctx, account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return nil, err + return nil, fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + + return newSettings, nil +} + +// validateExtraSettings validates the extra settings of the account. +func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error { + peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - return updatedAccount, nil + peerMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peerMap[peer.ID] = peer + } + + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID) } func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { @@ -1135,10 +1148,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { - am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } @@ -1674,33 +1687,18 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return err } - account, err := am.Store.GetAccountByUser(ctx, user.Id) + pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, user.Id) if err != nil { return err } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - pat, ok := account.Users[user.Id].PATs[tokenID] - if !ok { - return fmt.Errorf("token not found") - } - pat.LastUsed = time.Now().UTC() - return am.Store.SaveAccount(ctx, account) + return am.Store.SavePAT(ctx, LockingStrengthUpdate, pat) } // GetAccount returns an account associated with this account ID. diff --git a/management/server/dns.go b/management/server/dns.go index 7410aaa15..12a332156 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -6,6 +6,7 @@ import ( "strconv" "sync" + nbgroup "github.com/netbirdio/netbird/management/server/group" log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" @@ -94,56 +95,78 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err - } - - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") - } - if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if !user.HasAdminPower() || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") + } + + oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return err + } + if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) + err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups) if err != nil { return err } } - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil { + return fmt.Errorf("failed to update dns settings: %w", err) + } + + return nil + }) + if err != nil { return err } + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { + groupMap[g.ID] = g + } + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + group, ok := groupMap[id] + if ok { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + } } removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + group, ok := groupMap[id] + if ok { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + } } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 91caa1512..73bd5c35d 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -97,13 +97,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) + resp := toAccountResponse(accountID, updatedAccountSettings) util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 7e5057dbd..b3fb03f3d 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -89,7 +89,7 @@ type MockAccountManager struct { GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error @@ -667,7 +667,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) }