diff --git a/management/server/account.go b/management/server/account.go index da3203852..a9781b385 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,6 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error @@ -1261,6 +1262,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } +// AccountExists checks if an account exists. +func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) +} + // GetAccountIDByUserID retrieves the account ID based on the userID provided. // If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b6283a7e6..ec29222a4 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,6 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) @@ -58,7 +59,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -194,6 +195,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } +// AccountExists mock implementation of AccountExists from server.AccountManager interface +func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.AccountExistsFunc(ctx, accountID) + } + return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") +} + // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { if am.GetAccountIDByUserIdFunc != nil { @@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9e1ab27dc..d056015d8 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,7 +10,6 @@ import ( "path/filepath" "runtime" "runtime/debug" - "slices" "strings" "sync" "time" @@ -1033,84 +1032,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } -// AddUserPeersToGroups adds the user's peers to specified groups in database. -func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for _, pid := range userPeerIDs { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - -// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. -func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - if group.Name == "All" { - continue - } - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - if !slices.Contains(userPeerIDs, pid) { - update = append(update, pid) - } - } - - group.Peers = update - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)