Compare commits

...

4 Commits

Author SHA1 Message Date
Zoltan Papp
1165058fad [client] fix port collision in TestUpload (#5950)
* [debug] fix port collision in TestUpload

TestUpload hardcoded :8080, so it failed deterministically when anything
was already on that port and collided across concurrent test runs.
Bind a :0 listener in the test to get a kernel-assigned free port, and
add Server.Serve so tests can hand the listener in without reaching
into unexported state.

* [debug] drop test-only Server.Serve, use SERVER_ADDRESS env

The previous commit added a Server.Serve method on the upload-server,
used only by TestUpload. That left production with an unused function.
Reserve an ephemeral loopback port in the test, release it, and pass
the address through SERVER_ADDRESS (which the server already reads).
A small wait helper ensures the server is accepting connections before
the upload runs, so the close/rebind gap does not cause a false failure.
2026-04-21 19:07:20 +02:00
Zoltan Papp
703353d354 [flow] fix goroutine leak in TestReceive_ProtocolErrorStreamReconnect (#5951)
The Receive goroutine could outlive the test and call t.Logf after
teardown, panicking with "Log in goroutine after ... has completed".
Register a cleanup that waits for the goroutine to exit; ordering is
LIFO so it runs after client.Close, which is what unblocks Receive.
2026-04-21 19:06:47 +02:00
Zoltan Papp
2fb50aef6b [client] allow UDP packet loss in TestICEBind_HandlesConcurrentMixedTraffic (#5953)
The test writes 500 packets per family and asserted exact-count
delivery within a 5s window, even though its own comment says "Some
packet loss is acceptable for UDP". On FreeBSD/QEMU runners the writer
loops cannot always finish all 500 before the 5s deadline closes the
readers (we have seen 411/500 in CI).

The real assertion of this test is the routing check — IPv4 peer only
gets v4- packets, IPv6 peer only gets v6- packets — which remains
strict. Replace the exact-count assertions with a >=80% delivery
threshold so runner speed variance no longer causes false failures.
2026-04-21 19:05:58 +02:00
Vlad
eb3aa96257 [management] check policy for changes before actual db update (#5405) 2026-04-21 18:37:04 +02:00
8 changed files with 652 additions and 47 deletions

View File

@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++ ipv6Count++
} }
assert.Equal(t, packetsPerFamily, ipv4Count) // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
assert.Equal(t, packetsPerFamily, ipv6Count) // routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
} }
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -3,10 +3,12 @@ package debug
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci") t.Skip("Skipping upload test on docker ci")
} }
testDir := t.TempDir() testDir := t.TempDir()
testURL := "http://localhost:8080" addr := reserveLoopbackPort(t)
testURL := "http://" + addr
t.Setenv("SERVER_URL", testURL) t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir) t.Setenv("STORE_DIR", testDir)
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err) t.Errorf("Failed to stop server: %v", err)
} }
}) })
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile") file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content") fileContent := []byte("test file content")
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent) require.Equal(t, fileContent, createdFileContent)
} }
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err) require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() { t.Cleanup(func() {
err := client.Close() err := client.Close()
assert.NoError(t, err, "failed to close flow") assert.NoError(t, err, "failed to close flow")
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{}) receivedAfterReconnect := make(chan struct{})
go func() { go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 { if msg.IsInitiator || len(msg.EventId) == 0 {
return nil return nil

View File

@@ -5,6 +5,7 @@ import (
_ "embed" _ "embed"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
var isUpdate = policy.ID != "" var isUpdate = policy.ID != ""
var updateAccountPeers bool var updateAccountPeers bool
var action = activity.PolicyAdded var action = activity.PolicyAdded
var unchanged bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil { if err != nil {
return err return err
} }
saveFunc := transaction.CreatePolicy
if isUpdate { if isUpdate {
action = activity.PolicyUpdated if policy.Equal(existingPolicy) {
saveFunc = transaction.SavePolicy logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
} unchanged = true
return nil
}
if err = saveFunc(ctx, policy); err != nil { action = activity.PolicyUpdated
return err
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
} }
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if unchanged {
return policy, nil
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err return err
} }
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil { if err != nil {
return err return err
} }
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
if isUpdate { for _, rule := range policy.Rules {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
if err != nil {
return false, err
}
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil return true, nil
} }
} }
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil return true, nil
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
} }
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules. For updates it returns
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { // the existing policy loaded from the store so callers can avoid a second read.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
var existingPolicy *types.Policy
if policy.ID != "" { if policy.ID != "" {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) var err error
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return nil, err
} }
// TODO: Refactor to support multiple rules per policy // TODO: Refactor to support multiple rules per policy
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if rule.ID != "" && !existingRuleIDs[rule.ID] { if rule.ID != "" && !existingRuleIDs[rule.ID] {
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
} }
} }
} else { } else {
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return nil, err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return nil, err
} }
for i, rule := range policy.Rules { for i, rule := range policy.Rules {
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
} }
return nil return existingPolicy, nil
} }
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.

View File

@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
return c return c
} }
func (p *Policy) Equal(other *Policy) bool {
if p == nil || other == nil {
return p == other
}
if p.ID != other.ID ||
p.AccountID != other.AccountID ||
p.Name != other.Name ||
p.Description != other.Description ||
p.Enabled != other.Enabled {
return false
}
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
return false
}
if len(p.Rules) != len(other.Rules) {
return false
}
otherRules := make(map[string]*PolicyRule, len(other.Rules))
for _, r := range other.Rules {
otherRules[r.ID] = r
}
for _, r := range p.Rules {
otherRule, ok := otherRules[r.ID]
if !ok {
return false
}
if !r.Equal(otherRule) {
return false
}
}
return true
}
// EventMeta returns activity event meta related to this policy // EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any { func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name} return map[string]any{"name": p.Name}

View File

@@ -0,0 +1,193 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
a := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
},
}
b := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyEqual_DifferentRules(t *testing.T) {
a := &Policy{
ID: "pol1",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
},
}
b := &Policy{
ID: "pol1",
Enabled: true,
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
a := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
},
}
b := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
{ID: "r2", PolicyID: "pol1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
a := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
}
b := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
a := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc2"},
}
b := &Policy{
ID: "pol1",
SourcePostureChecks: []string{"pc1", "pc3"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
base := Policy{
ID: "pol1",
AccountID: "acc1",
Name: "test",
Description: "desc",
Enabled: true,
}
other := base
other.Name = "changed"
assert.False(t, base.Equal(&other))
other = base
other.Enabled = false
assert.False(t, base.Equal(&other))
other = base
other.Description = "changed"
assert.False(t, base.Equal(&other))
}
func TestPolicyEqual_NilCases(t *testing.T) {
var a *Policy
var b *Policy
assert.True(t, a.Equal(b))
a = &Policy{ID: "pol1"}
assert.False(t, a.Equal(nil))
}
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
a := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r1", PolicyID: "pol1"},
},
}
b := &Policy{
ID: "pol1",
Rules: []*PolicyRule{
{ID: "r2", PolicyID: "pol1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyEqual_FullScenario(t *testing.T) {
a := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "Web Access",
Description: "Allow web access",
Enabled: true,
SourcePostureChecks: []string{"pc2", "pc1"},
Rules: []*PolicyRule{
{
ID: "r1",
PolicyID: "pol1",
Name: "HTTP",
Enabled: true,
Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP,
Bidirectional: true,
Sources: []string{"g2", "g1"},
Destinations: []string{"g4", "g3"},
Ports: []string{"443", "80", "8080"},
PortRanges: []RulePortRange{
{Start: 8000, End: 9000},
{Start: 80, End: 80},
},
},
},
}
b := &Policy{
ID: "pol1",
AccountID: "acc1",
Name: "Web Access",
Description: "Allow web access",
Enabled: true,
SourcePostureChecks: []string{"pc1", "pc2"},
Rules: []*PolicyRule{
{
ID: "r1",
PolicyID: "pol1",
Name: "HTTP",
Enabled: true,
Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP,
Bidirectional: true,
Sources: []string{"g1", "g2"},
Destinations: []string{"g3", "g4"},
Ports: []string{"80", "8080", "443"},
PortRanges: []RulePortRange{
{Start: 80, End: 80},
{Start: 8000, End: 9000},
},
},
},
}
assert.True(t, a.Equal(b))
}

View File

@@ -1,6 +1,8 @@
package types package types
import ( import (
"slices"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
} }
return rule return rule
} }
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
if pm == nil || other == nil {
return pm == other
}
if pm.ID != other.ID ||
pm.PolicyID != other.PolicyID ||
pm.Name != other.Name ||
pm.Description != other.Description ||
pm.Enabled != other.Enabled ||
pm.Action != other.Action ||
pm.Bidirectional != other.Bidirectional ||
pm.Protocol != other.Protocol ||
pm.SourceResource != other.SourceResource ||
pm.DestinationResource != other.DestinationResource ||
pm.AuthorizedUser != other.AuthorizedUser {
return false
}
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
return false
}
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
return false
}
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
return false
}
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
return false
}
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
return false
}
return true
}
func stringSlicesEqualUnordered(a, b []string) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
sorted1 := make([]string, len(a))
sorted2 := make([]string, len(b))
copy(sorted1, a)
copy(sorted2, b)
slices.Sort(sorted1)
slices.Sort(sorted2)
return slices.Equal(sorted1, sorted2)
}
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
cmp := func(x, y RulePortRange) int {
if x.Start != y.Start {
if x.Start < y.Start {
return -1
}
return 1
}
if x.End != y.End {
if x.End < y.End {
return -1
}
return 1
}
return 0
}
sorted1 := make([]RulePortRange, len(a))
sorted2 := make([]RulePortRange, len(b))
copy(sorted1, a)
copy(sorted2, b)
slices.SortFunc(sorted1, cmp)
slices.SortFunc(sorted2, cmp)
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
return x.Start == y.Start && x.End == y.End
})
}
func authorizedGroupsEqual(a, b map[string][]string) bool {
if len(a) != len(b) {
return false
}
for k, va := range a {
vb, ok := b[k]
if !ok {
return false
}
if !stringSlicesEqualUnordered(va, vb) {
return false
}
}
return true
}

View File

@@ -0,0 +1,194 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "80", "22"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"22", "443", "80"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "80"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{"443", "22"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g2", "g3"},
Destinations: []string{"g4", "g5"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g3", "g1", "g2"},
Destinations: []string{"g5", "g4"},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g2"},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Sources: []string{"g1", "g3"},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 8000, End: 9000},
{Start: 80, End: 80},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 80},
{Start: 8000, End: 9000},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 80},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
PortRanges: []RulePortRange{
{Start: 80, End: 443},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u1", "u2", "u3"},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u3", "u1", "u2"},
},
}
assert.True(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g1": {"u1"},
},
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
AuthorizedGroups: map[string][]string{
"g2": {"u1"},
},
}
assert.False(t, a.Equal(b))
}
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
base := PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Name: "test",
Description: "desc",
Enabled: true,
Action: PolicyTrafficActionAccept,
Bidirectional: true,
Protocol: PolicyRuleProtocolTCP,
}
other := base
other.Name = "changed"
assert.False(t, base.Equal(&other))
other = base
other.Enabled = false
assert.False(t, base.Equal(&other))
other = base
other.Action = PolicyTrafficActionDrop
assert.False(t, base.Equal(&other))
other = base
other.Protocol = PolicyRuleProtocolUDP
assert.False(t, base.Equal(&other))
}
func TestPolicyRuleEqual_NilCases(t *testing.T) {
var a *PolicyRule
var b *PolicyRule
assert.True(t, a.Equal(b))
a = &PolicyRule{ID: "rule1"}
assert.False(t, a.Equal(nil))
}
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
a := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: []string{},
Sources: nil,
}
b := &PolicyRule{
ID: "rule1",
PolicyID: "pol1",
Ports: nil,
Sources: []string{},
}
assert.True(t, a.Equal(b))
}