diff --git a/management/server/account.go b/management/server/account.go index 11c3a17e0..523a47972 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1266,7 +1266,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // Returns the account ID or an error if none is found or created. func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - exists, err := am.Store.AccountExists(ctx, accountID) + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) if err != nil { return "", err } diff --git a/management/server/file_store.go b/management/server/file_store.go index e4307b1bd..994a4b1ee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -647,7 +647,7 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID return user, nil } -func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { account, err := s.getAccount(accountID) if err != nil { return nil, err @@ -985,7 +985,7 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStre } // AccountExists checks whether an account exists by the given ID. -func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) { +func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { _, exists := s.Accounts[id] return exists, nil } @@ -1002,7 +1002,7 @@ func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ st return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") } -func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) { +func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") } @@ -1011,7 +1011,7 @@ func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string } -func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) { +func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") } @@ -1019,7 +1019,7 @@ func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") } -func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) { +func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") } @@ -1027,7 +1027,7 @@ func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") } -func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ string) ([]*SetupKey, error) { +func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") } @@ -1035,7 +1035,7 @@ func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ stri return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") } -func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ string) ([]*dns.NameServerGroup, error) { +func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e059d2217..0eb5d9ae4 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -154,7 +154,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - return am.Store.GetAccountNameServerGroups(ctx, accountID) + return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { diff --git a/management/server/policy.go b/management/server/policy.go index c10be5c0c..5d07ba8f8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -395,7 +395,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - return am.Store.GetAccountPolicies(ctx, accountID) + return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 7a03effb1..9a4b679ce 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -116,7 +116,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return am.Store.GetAccountPostureChecks(ctx, accountID) + return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { diff --git a/management/server/route.go b/management/server/route.go index fbf4e82c5..6c1c8b1b3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -321,7 +321,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetAccountRoutes(ctx, accountID) + return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 3e7be8a16..9521e22d3 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -339,7 +339,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") } - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID) + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 31d9e9279..9e65ba9cf 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1054,10 +1054,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki } // AccountExists checks whether an account exists by the given ID. -func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) { +func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var count int64 - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, id).Count(&count) if result.Error != nil { return false, result.Error } @@ -1101,8 +1102,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetAccountPolicies retrieves policies for an account. -func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID) +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) } // GetPolicyByID retrieves a policy by its ID and account ID. @@ -1111,8 +1112,8 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng } // GetAccountPostureChecks retrieves posture checks for an account. -func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID) +func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { + return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) } // GetPostureChecksByID retrieves posture checks by their ID and account ID. @@ -1121,8 +1122,8 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin } // GetAccountRoutes retrieves network routes for an account. -func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db.WithContext(ctx), accountID) +func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { + return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) } // GetRouteByID retrieves a route by its ID and account ID. @@ -1131,8 +1132,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // GetAccountSetupKeys retrieves setup keys for an account. -func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), accountID) +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { + return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) } // GetSetupKeyByID retrieves a setup key by its ID and account ID. @@ -1141,8 +1142,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // GetAccountNameServerGroups retrieves name server groups for an account. -func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID) +func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. @@ -1151,10 +1152,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock } // getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) { +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T - result := db.Find(&record, accountIDCondition, accountID) + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) if err := result.Error; err != nil { parts := strings.Split(fmt.Sprintf("%T", record), ".") recordType := parts[len(parts)-1] diff --git a/management/server/store.go b/management/server/store.go index 62a0d72a4..f34a73c2d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -40,7 +40,7 @@ const ( type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) - AccountExists(ctx context.Context, id string) (bool, error) + AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) @@ -69,11 +69,11 @@ type Store interface { GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error - GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) + GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) @@ -87,13 +87,13 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) - GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) + GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) - GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*dns.NameServerGroup, error) + GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)