mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Compare commits
4 Commits
fix-crowds
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 |
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
// routing-correctness checks above are the real assertions; the counts
|
||||||
|
// are a sanity bound to catch a totally silent path.
|
||||||
|
minDelivered := packetsPerFamily * 80 / 100
|
||||||
|
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||||
|
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
testURL := "http://localhost:8080"
|
addr := reserveLoopbackPort(t)
|
||||||
|
testURL := "http://" + addr
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
|
t.Setenv("SERVER_ADDRESS", addr)
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
waitForServer(t, addr)
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||||
|
// address, then releases it so the server under test can rebind. The close/
|
||||||
|
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||||
|
// it's essentially never contended in practice.
|
||||||
|
func reserveLoopbackPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
addr := l.Addr().String()
|
||||||
|
require.NoError(t, l.Close())
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForServer(t *testing.T, addr string) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("server did not start listening on %s in time", addr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
|
|
||||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||||
|
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||||
|
// can outlive the test and call t.Logf after teardown, panicking.
|
||||||
|
receiveDone := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
select {
|
||||||
|
case <-receiveDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("Receive goroutine did not exit after Close")
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
assert.NoError(t, err, "failed to close flow")
|
assert.NoError(t, err, "failed to close flow")
|
||||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
receivedAfterReconnect := make(chan struct{})
|
receivedAfterReconnect := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(receiveDone)
|
||||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
var isUpdate = policy.ID != ""
|
var isUpdate = policy.ID != ""
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var action = activity.PolicyAdded
|
var action = activity.PolicyAdded
|
||||||
|
var unchanged bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
saveFunc := transaction.CreatePolicy
|
|
||||||
if isUpdate {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
if policy.Equal(existingPolicy) {
|
||||||
saveFunc = transaction.SavePolicy
|
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||||
}
|
unchanged = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = saveFunc(ctx, policy); err != nil {
|
action = activity.PolicyUpdated
|
||||||
return err
|
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if unchanged {
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||||
if isUpdate {
|
for _, rule := range policy.Rules {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !policy.Enabled && !existingPolicy.Enabled {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range existingPolicy.Rules {
|
|
||||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasPeers {
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
|
}
|
||||||
|
|
||||||
|
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||||
|
if !policy.Enabled && !existingPolicy.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range existingPolicy.Rules {
|
||||||
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
|||||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePolicy validates the policy and its rules.
|
// validatePolicy validates the policy and its rules. For updates it returns
|
||||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
// the existing policy loaded from the store so callers can avoid a second read.
|
||||||
|
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||||
|
var existingPolicy *types.Policy
|
||||||
if policy.ID != "" {
|
if policy.ID != "" {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
var err error
|
||||||
|
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Refactor to support multiple rules per policy
|
// TODO: Refactor to support multiple rules per policy
|
||||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, rule := range policy.Rules {
|
for i, rule := range policy.Rules {
|
||||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return existingPolicy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
|
|||||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Policy) Equal(other *Policy) bool {
|
||||||
|
if p == nil || other == nil {
|
||||||
|
return p == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ID != other.ID ||
|
||||||
|
p.AccountID != other.AccountID ||
|
||||||
|
p.Name != other.Name ||
|
||||||
|
p.Description != other.Description ||
|
||||||
|
p.Enabled != other.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.Rules) != len(other.Rules) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||||
|
for _, r := range other.Rules {
|
||||||
|
otherRules[r.ID] = r
|
||||||
|
}
|
||||||
|
for _, r := range p.Rules {
|
||||||
|
otherRule, ok := otherRules[r.ID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !r.Equal(otherRule) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to this policy
|
// EventMeta returns activity event meta related to this policy
|
||||||
func (p *Policy) EventMeta() map[string]any {
|
func (p *Policy) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": p.Name}
|
return map[string]any{"name": p.Name}
|
||||||
|
|||||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Description = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||||
|
var a *Policy
|
||||||
|
var b *Policy
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &Policy{ID: "pol1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g2", "g1"},
|
||||||
|
Destinations: []string{"g4", "g3"},
|
||||||
|
Ports: []string{"443", "80", "8080"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
Destinations: []string{"g3", "g4"},
|
||||||
|
Ports: []string{"80", "8080", "443"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
}
|
}
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||||
|
if pm == nil || other == nil {
|
||||||
|
return pm == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.ID != other.ID ||
|
||||||
|
pm.PolicyID != other.PolicyID ||
|
||||||
|
pm.Name != other.Name ||
|
||||||
|
pm.Description != other.Description ||
|
||||||
|
pm.Enabled != other.Enabled ||
|
||||||
|
pm.Action != other.Action ||
|
||||||
|
pm.Bidirectional != other.Bidirectional ||
|
||||||
|
pm.Protocol != other.Protocol ||
|
||||||
|
pm.SourceResource != other.SourceResource ||
|
||||||
|
pm.DestinationResource != other.DestinationResource ||
|
||||||
|
pm.AuthorizedUser != other.AuthorizedUser {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSlicesEqualUnordered(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
sorted1 := make([]string, len(a))
|
||||||
|
sorted2 := make([]string, len(b))
|
||||||
|
copy(sorted1, a)
|
||||||
|
copy(sorted2, b)
|
||||||
|
slices.Sort(sorted1)
|
||||||
|
slices.Sort(sorted2)
|
||||||
|
return slices.Equal(sorted1, sorted2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cmp := func(x, y RulePortRange) int {
|
||||||
|
if x.Start != y.Start {
|
||||||
|
if x.Start < y.Start {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if x.End != y.End {
|
||||||
|
if x.End < y.End {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
sorted1 := make([]RulePortRange, len(a))
|
||||||
|
sorted2 := make([]RulePortRange, len(b))
|
||||||
|
copy(sorted1, a)
|
||||||
|
copy(sorted2, b)
|
||||||
|
slices.SortFunc(sorted1, cmp)
|
||||||
|
slices.SortFunc(sorted2, cmp)
|
||||||
|
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||||
|
return x.Start == y.Start && x.End == y.End
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, va := range a {
|
||||||
|
vb, ok := b[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(va, vb) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80", "22"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"22", "443", "80"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "22"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2", "g3"},
|
||||||
|
Destinations: []string{"g4", "g5"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g3", "g1", "g2"},
|
||||||
|
Destinations: []string{"g5", "g4"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 443},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1", "u2", "u3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u3", "u1", "u2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g2": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Action = PolicyTrafficActionDrop
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Protocol = PolicyRuleProtocolUDP
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||||
|
var a *PolicyRule
|
||||||
|
var b *PolicyRule
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &PolicyRule{ID: "rule1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{},
|
||||||
|
Sources: nil,
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: nil,
|
||||||
|
Sources: []string{},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user