diff --git a/management/server/file_store.go b/management/server/file_store.go index e4307b1bd..4c42bde0b 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -118,7 +118,7 @@ func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) err return s.SaveAccount(ctx, account) } -func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { +func (s *FileStore) IncrementNetworkSerial(ctx context.Context, _ LockingStrength, accountId string) error { s.mux.Lock() defer s.mux.Unlock() @@ -1011,6 +1011,14 @@ func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string } +func (s *FileStore) SavePolicy(_ context.Context, _ LockingStrength, _ *Policy) error { + return status.Errorf(status.Internal, "SavePolicy is not implemented") +} + +func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string) error { + return status.Errorf(status.Internal, "DeletePolicy is not implemented") +} + func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") } diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 5b6e1121c..5bdc62e1a 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -130,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy := server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, diff --git a/management/server/peer.go b/management/server/peer.go index da9586734..c652ade3f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -502,7 +502,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } diff --git a/management/server/policy.go b/management/server/policy.go index c10be5c0c..db2e06ebc 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,7 +3,7 @@ package server import ( "context" _ "embed" - "slices" + "fmt" "strconv" "strings" @@ -321,7 +321,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies") } return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) @@ -329,20 +329,48 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - if err = am.savePolicy(account, policy, isUpdate); err != nil { + if !user.HasAdminPower() || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "only admin users are allowed to update policies") + } + + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { return err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + postureChecks, err := am.Store.GetAccountPostureChecks(ctx, accountID) + if err != nil { + return err + } + + for index, rule := range policy.Rules { + rule.Sources = getValidGroupIDs(groups, rule.Sources) + rule.Destinations = getValidGroupIDs(groups, rule.Destinations) + policy.Rules[index] = rule + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy) + if err != nil { + return fmt.Errorf("failed to save policy: %w", err) + } + return nil + }) + if err != nil { return err } @@ -352,6 +380,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil @@ -359,26 +391,42 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - policy, err := am.deletePolicy(account, policyID) + if !user.HasAdminPower() || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "only admin users are allowed to delete policies") + } + + policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) if err != nil { return err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, policyID) + if err != nil { + return fmt.Errorf("failed to delete policy: %w", err) + } + return nil + }) + if err != nil { return err } - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil @@ -392,7 +440,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies") } return am.Store.GetAccountPolicies(ctx, accountID) @@ -415,36 +463,6 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - - if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) - } - - // Update the existing policy - account.Policies[policyIdx] = policyToSave - return nil - } - - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return nil -} - func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(update)) for i := range update { @@ -558,28 +576,36 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { + validPostureCheckIDs := make(map[string]struct{}) + for _, check := range postureChecks { + validPostureCheckIDs[check.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := validPostureCheckIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { + validGroupIDs := make(map[string]struct{}) + for _, group := range groups { + validGroupIDs[group.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := validGroupIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index f1533e850..6cf52836d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1007,8 +1007,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { return status.Errorf(status.Internal, "issue incrementing network serial count") } @@ -1106,6 +1107,18 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) } +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&policy).Error +} + +// DeletePolicy deletes a policy from the database. +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID string) error { + return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, idQueryCondition, policyID).Error +} + // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID) diff --git a/management/server/store.go b/management/server/store.go index 62a0d72a4..4ac58f6ee 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -71,6 +71,8 @@ type Store interface { GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) @@ -97,7 +99,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string