diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..48297ca11 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,6 +5,7 @@ import ( _ "embed" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var isUpdate = policy.ID != "" var updateAccountPeers bool var action = activity.PolicyAdded + var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { - return err - } - - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } - saveFunc := transaction.CreatePolicy if isUpdate { - action = activity.PolicyUpdated - saveFunc = transaction.SavePolicy - } + if policy.Equal(existingPolicy) { + logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID) + unchanged = true + return nil + } - if err = saveFunc(ctx, policy); err != nil { - return err + action = activity.PolicyUpdated + + updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) + if err != nil { + return err + } + + if err = transaction.SavePolicy(ctx, policy); err != nil { + return err + } + } else { + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) + if err != nil { + return err + } + + if err = transaction.CreatePolicy(ctx, policy); err != nil { + return err + } } return transaction.IncrementNetworkSerial(ctx, accountID) @@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } + if unchanged { + return policy, nil + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { @@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil } } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) +} + +func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } + + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d4e1a8816..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy { return c } +func (p *Policy) Equal(other *Policy) bool { + if p == nil || other == nil { + return p == other + } + + if p.ID != other.ID || + p.AccountID != other.AccountID || + p.Name != other.Name || + p.Description != other.Description || + p.Enabled != other.Enabled { + return false + } + + if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) { + return false + } + + if len(p.Rules) != len(other.Rules) { + return false + } + + otherRules := make(map[string]*PolicyRule, len(other.Rules)) + for _, r := range other.Rules { + otherRules[r.ID] = r + } + for _, r := range p.Rules { + otherRule, ok := otherRules[r.ID] + if !ok { + return false + } + if !r.Equal(otherRule) { + return false + } + } + + return true +} + // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go new file mode 100644 index 000000000..b1d7aabc2 --- /dev/null +++ b/management/server/types/policy_test.go @@ -0,0 +1,193 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRules(t *testing.T) { + a := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + b := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRuleCount(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2", "pc3"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentPostureChecks(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentScalarFields(t *testing.T) { + base := Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Description: "desc", + Enabled: true, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Description = "changed" + assert.False(t, base.Equal(&other)) +} + +func TestPolicyEqual_NilCases(t *testing.T) { + var a *Policy + var b *Policy + assert.True(t, a.Equal(b)) + + a = &Policy{ID: "pol1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyEqual_RulesMismatchByID(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_FullScenario(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc2", "pc1"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g2", "g1"}, + Destinations: []string{"g4", "g3"}, + Ports: []string{"443", "80", "8080"}, + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + }, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc1", "pc2"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g1", "g2"}, + Destinations: []string{"g3", "g4"}, + Ports: []string{"80", "8080", "443"}, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + }, + }, + } + assert.True(t, a.Equal(b)) +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index bb75dd555..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -1,6 +1,8 @@ package types import ( + "slices" + "github.com/netbirdio/netbird/shared/management/proto" ) @@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule { } return rule } + +func (pm *PolicyRule) Equal(other *PolicyRule) bool { + if pm == nil || other == nil { + return pm == other + } + + if pm.ID != other.ID || + pm.PolicyID != other.PolicyID || + pm.Name != other.Name || + pm.Description != other.Description || + pm.Enabled != other.Enabled || + pm.Action != other.Action || + pm.Bidirectional != other.Bidirectional || + pm.Protocol != other.Protocol || + pm.SourceResource != other.SourceResource || + pm.DestinationResource != other.DestinationResource || + pm.AuthorizedUser != other.AuthorizedUser { + return false + } + + if !stringSlicesEqualUnordered(pm.Sources, other.Sources) { + return false + } + if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) { + return false + } + if !stringSlicesEqualUnordered(pm.Ports, other.Ports) { + return false + } + if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) { + return false + } + if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) { + return false + } + + return true +} + +func stringSlicesEqualUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + sorted1 := make([]string, len(a)) + sorted2 := make([]string, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.Sort(sorted1) + slices.Sort(sorted2) + return slices.Equal(sorted1, sorted2) +} + +func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + cmp := func(x, y RulePortRange) int { + if x.Start != y.Start { + if x.Start < y.Start { + return -1 + } + return 1 + } + if x.End != y.End { + if x.End < y.End { + return -1 + } + return 1 + } + return 0 + } + sorted1 := make([]RulePortRange, len(a)) + sorted2 := make([]RulePortRange, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.SortFunc(sorted1, cmp) + slices.SortFunc(sorted2, cmp) + return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool { + return x.Start == y.Start && x.End == y.End + }) +} + +func authorizedGroupsEqual(a, b map[string][]string) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok { + return false + } + if !stringSlicesEqualUnordered(va, vb) { + return false + } + } + return true +} diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go new file mode 100644 index 000000000..816e72abb --- /dev/null +++ b/management/server/types/policyrule_test.go @@ -0,0 +1,194 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80", "22"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"22", "443", "80"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPorts(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "22"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2", "g3"}, + Destinations: []string{"g4", "g5"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g3", "g1", "g2"}, + Destinations: []string{"g5", "g4"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentSources(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 443}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1", "u2", "u3"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u3", "u1", "u2"}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g2": {"u1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) { + base := PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Name: "test", + Description: "desc", + Enabled: true, + Action: PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Action = PolicyTrafficActionDrop + assert.False(t, base.Equal(&other)) + + other = base + other.Protocol = PolicyRuleProtocolUDP + assert.False(t, base.Equal(&other)) +} + +func TestPolicyRuleEqual_NilCases(t *testing.T) { + var a *PolicyRule + var b *PolicyRule + assert.True(t, a.Equal(b)) + + a = &PolicyRule{ID: "rule1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyRuleEqual_EmptySlices(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{}, + Sources: nil, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: nil, + Sources: []string{}, + } + assert.True(t, a.Equal(b)) +} +