diff --git a/management/server/policy.go b/management/server/policy.go index c94ae65c3..4e3a227da 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -49,19 +49,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) + if err != nil { return err } if isUpdate { - existingPolicy, getErr := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if getErr != nil { - return getErr - } - - existingPolicy.Normalize() - policy.Normalize() - if policy.Equal(existingPolicy) { unchanged = true return nil @@ -78,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } } else { - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -126,7 +119,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -163,17 +156,8 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - return arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) - } - +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -212,12 +196,15 @@ func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction st return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -228,7 +215,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -238,12 +225,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -262,7 +249,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 101071ed0..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -3,7 +3,6 @@ package types import ( "errors" "fmt" - "slices" "strconv" "strings" ) @@ -132,16 +131,6 @@ func (p *Policy) Equal(other *Policy) bool { return true } -func (p *Policy) Normalize() { - if p == nil { - return - } - slices.Sort(p.SourcePostureChecks) - for _, r := range p.Rules { - r.Normalize() - } -} - // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go index 7c70df7bc..b1d7aabc2 100644 --- a/management/server/types/policy_test.go +++ b/management/server/types/policy_test.go @@ -136,31 +136,6 @@ func TestPolicyEqual_RulesMismatchByID(t *testing.T) { assert.False(t, a.Equal(b)) } -func TestPolicyNormalize(t *testing.T) { - p := &Policy{ - SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, - Rules: []*PolicyRule{ - { - ID: "r1", - Sources: []string{"g2", "g1"}, - Destinations: []string{"g4", "g3"}, - Ports: []string{"443", "80"}, - }, - }, - } - p.Normalize() - - assert.Equal(t, []string{"pc1", "pc2", "pc3"}, p.SourcePostureChecks) - assert.Equal(t, []string{"g1", "g2"}, p.Rules[0].Sources) - assert.Equal(t, []string{"g3", "g4"}, p.Rules[0].Destinations) - assert.Equal(t, []string{"443", "80"}, p.Rules[0].Ports) -} - -func TestPolicyNormalize_Nil(t *testing.T) { - var p *Policy - p.Normalize() -} - func TestPolicyEqual_FullScenario(t *testing.T) { a := &Policy{ ID: "pol1", diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index 58a1b7344..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -2,7 +2,6 @@ package types import ( "slices" - "sort" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -160,25 +159,6 @@ func (pm *PolicyRule) Equal(other *PolicyRule) bool { return true } -func (pm *PolicyRule) Normalize() { - if pm == nil { - return - } - slices.Sort(pm.Sources) - slices.Sort(pm.Destinations) - slices.Sort(pm.Ports) - sort.Slice(pm.PortRanges, func(i, j int) bool { - if pm.PortRanges[i].Start != pm.PortRanges[j].Start { - return pm.PortRanges[i].Start < pm.PortRanges[j].Start - } - return pm.PortRanges[i].End < pm.PortRanges[j].End - }) - for k, v := range pm.AuthorizedGroups { - slices.Sort(v) - pm.AuthorizedGroups[k] = v - } -} - func stringSlicesEqualUnordered(a, b []string) bool { if len(a) != len(b) { return false diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go index 606dfe7c7..816e72abb 100644 --- a/management/server/types/policyrule_test.go +++ b/management/server/types/policyrule_test.go @@ -192,34 +192,3 @@ func TestPolicyRuleEqual_EmptySlices(t *testing.T) { assert.True(t, a.Equal(b)) } -func TestPolicyRuleNormalize(t *testing.T) { - rule := &PolicyRule{ - Sources: []string{"g3", "g1", "g2"}, - Destinations: []string{"g6", "g4", "g5"}, - Ports: []string{"443", "80", "22"}, - PortRanges: []RulePortRange{ - {Start: 8000, End: 9000}, - {Start: 80, End: 80}, - {Start: 80, End: 443}, - }, - AuthorizedGroups: map[string][]string{ - "g1": {"u3", "u1", "u2"}, - }, - } - rule.Normalize() - - assert.Equal(t, []string{"g1", "g2", "g3"}, rule.Sources) - assert.Equal(t, []string{"g4", "g5", "g6"}, rule.Destinations) - assert.Equal(t, []string{"22", "443", "80"}, rule.Ports) - assert.Equal(t, []RulePortRange{ - {Start: 80, End: 80}, - {Start: 80, End: 443}, - {Start: 8000, End: 9000}, - }, rule.PortRanges) - assert.Equal(t, []string{"u1", "u2", "u3"}, rule.AuthorizedGroups["g1"]) -} - -func TestPolicyRuleNormalize_Nil(t *testing.T) { - var rule *PolicyRule - rule.Normalize() -}