diff --git a/management/server/file_store.go b/management/server/file_store.go index 50eeca596..981c8c653 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -963,7 +963,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } @@ -1112,3 +1112,18 @@ func (s *FileStore) SaveDNSSettings(_ context.Context, _ LockingStrength, _ stri func (s *FileStore) SaveAccountSettings(_ context.Context, _ LockingStrength, _ string, _ *Settings) error { return status.Errorf(status.Internal, "SaveAccountSettings is not implemented") } + +func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error { + return status.Errorf(status.Internal, "SaveGroup is not implemented") +} + +func (s *FileStore) DeleteGroup(_ context.Context, _ LockingStrength, _, _ string) error { + return status.Errorf(status.Internal, "DeleteGroup is not implemented") +} +func (s *FileStore) DeleteGroups(_ context.Context, _ LockingStrength, _ []string, _ string) error { + return status.Errorf(status.Internal, "DeleteGroups is not implemented") +} + +func (s *FileStore) GetAccountUsers(_ context.Context, _ LockingStrength, _ string) ([]*User, error) { + return nil, status.Errorf(status.Internal, "GetAccountUsers is not implemented") +} diff --git a/management/server/group.go b/management/server/group.go index aa387c058..7c6ece6d7 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -69,21 +69,24 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, // SaveGroup object of the peers func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) } // SaveGroups adds new groups to the account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var eventsToStore []func() + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "no permission to create group") + } + + var ( + eventsToStore []func() + groupsToSave []*nbgroup.Group + ) for _, newGroup := range newGroups { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { @@ -91,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) + existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID) if err != nil { s, ok := status.FromError(err) if !ok || s.ErrorType != status.NotFound { @@ -109,40 +112,54 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { + if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) + events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) eventsToStore = append(eventsToStore, events...) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { + return fmt.Errorf("failed to save groups: %w", err) + } + return nil + }) + if err != nil { return err } - am.updateAccountPeers(ctx, account) - for _, storeEvent := range eventsToStore { storeEvent() } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -152,12 +169,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range addedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, @@ -168,12 +186,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, @@ -203,85 +222,108 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return nil + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "no permission to delete group") } - allGroup, err := account.GetGroupAll() + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) if err != nil { return err } - if allGroup.ID == groupID { + if group.Name == "All" { return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } - if err = validateDeleteGroup(account, group, userId); err != nil { - return err - } - delete(account.Groups, groupID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = am.validateDeleteGroup(ctx, group, userID); err != nil { return err } - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil } // DeleteGroups deletes groups from an account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. -// -// If an error occurs while deleting a group, the function skips it and continues deleting other groups. -// Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var allErrors error + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "no permission to delete groups") + } + + var ( + allErrors error + groupIDsToDelete []string + deletedGroups []*nbgroup.Group + ) - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + if err != nil { continue } - if err := validateDeleteGroup(account, group, userId); err != nil { + if err := am.validateDeleteGroup(ctx, group, userID); err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) continue } - delete(account.Groups, groupID) + groupIDsToDelete = append(groupIDsToDelete, groupID) deletedGroups = append(deletedGroups, group) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { return err } - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return allErrors @@ -371,11 +413,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { @@ -383,32 +425,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.ID, group.AccountID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.ID, group.AccountID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.ID, group.AccountID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.ID, group.AccountID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.ID, group.AccountID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil { + if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { return &GroupLinkError{"integrated validator", group.Name} } } @@ -417,17 +469,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, groupID string, accountID string) (bool, *route.Route) { + routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, groupID string, accountID string) (bool, *Policy) { + policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -439,7 +504,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, groupID string, accountID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -447,11 +518,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, groupID string, accountID string) (bool, *SetupKey) { + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -461,7 +539,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, groupID string, accountID string) (bool, *User) { + users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 3bef8d410..6e289eb7e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -378,17 +378,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } -// SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) - } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error -} - // DeleteHashedPAT2TokenIDIndex is noop in SqlStore func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { return nil @@ -500,6 +489,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +// GetAccountUsers returns all users associated with the account. +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { + return getRecords[User](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) @@ -1135,9 +1129,38 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the database. +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + return saveRecord[nbgroup.Group](s.db.WithContext(ctx), lockStrength, group) +} + +// SaveGroups saves the given list of groups to the database. +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + } + return nil +} + +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error { + return deleteRecordByID[nbgroup.Group](s.db.WithContext(ctx), lockStrength, groupID, accountID) +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}). + Where("account_id AND id IN ?", accountID, groupIDs).Delete(&nbgroup.Group{}) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) + return getRecords[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) } // GetPolicyByID retrieves a policy by its ID and account ID. @@ -1159,7 +1182,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) } // GetPostureChecksByID retrieves posture checks by their ID and account ID. @@ -1188,7 +1211,7 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[route.Route](s.db.WithContext(ctx), lockStrength, accountID) } // GetRouteByID retrieves a route by its ID and account ID. @@ -1209,7 +1232,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[SetupKey](s.db.WithContext(ctx), lockStrength, accountID) } // GetSetupKeyByID retrieves a setup key by its ID and account ID. @@ -1231,7 +1254,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. @@ -1277,13 +1300,13 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, // GetAccountPeers retrieves peers for an account. func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID) + return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID) } // GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user. func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true) - return getRecords[*nbpeer.Peer](db, lockStrength, accountID) + return getRecords[nbpeer.Peer](db, lockStrength, accountID) } // GetPeerByID retrieves a peer by its ID and account ID. @@ -1292,14 +1315,12 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength } // getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - var record []T +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]*T, error) { + var record []*T result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - + recordType := getRecordType(record) return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) } @@ -1313,8 +1334,7 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&record, accountAndIDQueryCondition, accountID, recordID) if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] + recordType := getRecordType(record) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "%s not found", recordType) @@ -1324,15 +1344,23 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a return &record, nil } +// saveRecord saves a record to the database. +func saveRecord[T any](db *gorm.DB, lockStrength LockingStrength, record *T) error { + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(record) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save %s to store: %v", getRecordType(record), result.Error) + } + + return nil +} + // deleteRecordByID deletes a record by its ID and account ID from the database. func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { var record T - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] + recordType := getRecordType(record) - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&record, accountAndIDQueryCondition, accountID, recordID) + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&record, accountAndIDQueryCondition, accountID, recordID) if err := result.Error; err != nil { return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) } @@ -1343,3 +1371,8 @@ func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID return nil } + +func getRecordType(record any) string { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + return parts[len(parts)-1] +} diff --git a/management/server/store.go b/management/server/store.go index a3fdd012e..a53b5ad70 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,6 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) @@ -75,7 +76,10 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)