diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 77f4f0306..33fdc4b3d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,9 +1,11 @@ package nftables import ( + "bytes" "fmt" "net" "net/netip" + "os/exec" "testing" "time" @@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) { }) } } + +func runIptablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("iptables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + require.NoError(t, err, "iptables-save failed to run") + + return stdout.String(), stderr.String() +} + +func verifyIptablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + // Check for any incompatibility warnings + require.NotContains(t, + stderr, + "incompatible", + "iptables-save produced compatibility warning. Full stderr: %s", + stderr, + ) + + // Verify standard tables are present + expectedTables := []string{ + "*filter", + "*nat", + "*mangle", + } + + for _, table := range expectedTables { + require.Contains(t, + stdout, + table, + "iptables-save output missing expected table: %s\nFull stdout: %s", + table, + stdout, + ) + } +} + +func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Reset(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + ip := net.ParseIP("100.96.0.1") + _, err = manager.AddPeerFiltering( + ip, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "test rule", + ) + require.NoError(t, err, "failed to add peer filtering rule") + + _, err = manager.AddRouteFiltering( + []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, + netip.MustParsePrefix("10.1.0.0/24"), + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + pair := fw.RouterPair{ + Source: netip.MustParsePrefix("192.168.1.0/24"), + Destination: netip.MustParsePrefix("10.0.0.0/24"), + Masquerade: true, + } + err = manager.AddNatRule(pair) + require.NoError(t, err, "failed to add NAT rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index af5dc6733..fb726395b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { // SetLegacyManagement doesn't need to be implemented for this manager func (m *Manager) SetLegacyManagement(isLegacy bool) error { if m.nativeFirewall == nil { - return errRouteNotSupported + return nil } return m.nativeFirewall.SetLegacyManagement(isLegacy) } diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go new file mode 100644 index 000000000..b8a865e39 --- /dev/null +++ b/client/iface/bind/control_android.go @@ -0,0 +1,12 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func init() { + // ControlFns is not thread safe and should only be modified during init. + *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 71a0f26ae..455e3407e 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() } // setIsLegacy sets the legacy routing setup @@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } - - originalValues, err := sysctl.Setup(r.wgInterface) - if err != nil { - log.Errorf("Error setting up sysctl: %v", err) - sysctlFailed = true - } - originalSysctl = originalValues - defer func() { if err != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { @@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } } + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + originalValues, err := sysctl.Setup(r.wgInterface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + return nil, nil, nil } @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) { return fmt.Errorf("remove routing rule: %w", err) } diff --git a/go.mod b/go.mod index f3b633dfc..73dbceddc 100644 --- a/go.mod +++ b/go.mod @@ -240,7 +240,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 59cbdd959..8b42ac45f 100644 --- a/go.sum +++ b/go.sum @@ -532,8 +532,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= diff --git a/management/server/account.go b/management/server/account.go index f5271c0a8..374233617 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..dbabfd366 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -6,13 +6,17 @@ import ( b64 "encoding/base64" "encoding/json" "fmt" + "io" "net" + "os" "reflect" + "strconv" "sync" "testing" "time" "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } b.Run("public without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) if err != nil { @@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { }) b.Run("private without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { b.Run("private with account ID", func(b *testing.B) { claims.AccountId = id - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1238,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1421,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -2987,3 +2986,188 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) t.Error("Timed out waiting for update message") } } + +func BenchmarkSyncAndMarkPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 1, 3}, + {"Medium", 500, 100, 7, 13}, + {"Large", 5000, 200, 65, 80}, + {"Small single", 50, 10, 1, 3}, + {"Medium single", 500, 10, 7, 13}, + {"Large 5", 5000, 15, 65, 80}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } +} + +func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 102, 110}, + {"Medium", 500, 100, 105, 140}, + {"Large", 5000, 200, 160, 200}, + {"Small single", 50, 10, 102, 110}, + {"Medium single", 500, 10, 105, 140}, + {"Large 5", 5000, 15, 160, 200}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: account.Peers["peer-1"].Key, + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } +} + +func BenchmarkLoginPeer_NewPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 107, 120}, + {"Medium", 500, 100, 105, 140}, + {"Large", 5000, 200, 180, 220}, + {"Small single", 50, 10, 107, 120}, + {"Medium single", 500, 10, 105, 140}, + {"Large 5", 5000, 15, 180, 220}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: "some-new-key" + strconv.Itoa(i), + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } +} diff --git a/management/server/dns.go b/management/server/dns.go index 4551be5ab..8df211b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "slices" "strconv" "sync" @@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err - } - - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") - } - if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) - if err != nil { - return err - } - } - - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() - - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return err } - for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + if !user.HasAdminPower() { + return status.NewAdminPermissionError() } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + var updateAccountPeers bool + var eventsToStore []func() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { + return err + } + + oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + + updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) + if err != nil { + return err + } + + events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) + eventsToStore = append(eventsToStore, events...) + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } +// prepareDNSSettingsEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) + return nil + } + + for _, groupID := range addedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + }) + + } + + for _, groupID := range removedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + }) + } + + return eventsToStore +} + +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) +} + +// validateDNSSettings validates the DNS settings. +func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { + if len(settings.DisabledManagementGroups) == 0 { + return nil + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + if err != nil { + return err + } + + return validateGroups(settings.DisabledManagementGroups, groups) +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ diff --git a/management/server/group.go b/management/server/group.go index a36213f04..7b307cf1a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + + for _, group := range groups { + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 59094a23e..ec017fc57 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, } for _, rule := range req.Rules { + var ruleID string + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, + PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 228ebcbce..f8a897eb2 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } - return nil + return policy, nil }, GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bc..2c8204292 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d83..f400cec81 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b15..46a4fbc1f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714..e7a5387a1 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), + AccountID: accountID, Name: name, Description: description, NameServers: nameServerList, @@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - err = validateNameServerGroup(false, newNSGroup, account) + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + }) if err != nil { return nil, err } - if account.NameServerGroups == nil { - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) - } + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) - account.NameServerGroups[newNSGroup.ID] = newNSGroup - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if anyGroupHasPeers(account, newNSGroup.Groups) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } @@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = validateNameServerGroup(true, nsGroupToSave, account) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + if err != nil { + return err + } + nsGroupToSave.AccountID = accountID + + if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil { + return err + } + + updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + }) if err != nil { return err } - oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] - account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - nsGroup := account.NameServerGroups[nsGroupID] - if nsGroup == nil { - return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - delete(account.NameServerGroups, nsGroupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + var nsGroup *nbdns.NameServerGroup + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + if err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + }) + if err != nil { return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } @@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } -func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { - nsGroupID := "" - if existingGroup { - nsGroupID = nameserverGroup.ID - _, found := account.NameServerGroups[nsGroupID] - if !found { - return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) - } - } - +func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err } - err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) - if err != nil { - return err - } - err = validateNSList(nameserverGroup.NameServers) if err != nil { return err } - err = validateGroups(nameserverGroup.Groups, account.Groups) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } - return nil + err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups) + if err != nil { + return err + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + if err != nil { + return err + } + + return validateGroups(nameserverGroup.Groups, groups) +} + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false, nil + } + + hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo return nil } -func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { +func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } - for _, nsGroup := range nsGroupMap { + for _, nsGroup := range groups { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name) } } @@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na } func validateNSList(list []nbdns.NameServer) error { - nsListLenght := len(list) - if nsListLenght == 0 || nsListLenght > 3 { + nsListLength := len(list) + if nsListLength == 0 || nsListLength > 3 { return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) } return nil @@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if id == "" { return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } - found := false - for groupID := range groups { - if id == groupID { - found = true - break - } - } - if !found { + if _, found := groups[id]; !found { return status.Errorf(status.InvalidArgument, "group id %s not found", id) } } @@ -277,11 +340,3 @@ func validateDomain(domain string) error { return nil } - -// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { - if !newNSGroup.Enabled && !oldNSGroup.Enabled { - return false - } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) -} diff --git a/management/server/peer.go b/management/server/peer.go index beb833dba..d45bb1a50 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4e2dcb2c3..e5eaa20d6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { var ( group1 nbgroup.Group group2 nbgroup.Group - policy Policy ) group1.ID = xid.New().String() group2.ID = xid.New().String() group1.Name = "src" group2.Name = "dst" - policy.ID = xid.New().String() group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy.Name = "test" - policy.Enabled = true - policy.Rules = []*PolicyRule{ - { - Enabled: true, - Sources: []string{group1.ID}, - Destinations: []string{group2.ID}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, + policy := &Policy{ + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{group1.ID}, + Destinations: []string{group2.ID}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -833,19 +833,20 @@ func BenchmarkGetPeers(b *testing.B) { }) } } - func BenchmarkUpdateAccountPeers(b *testing.B) { benchCases := []struct { - name string - peers int - groups int + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 }{ - {"Small", 50, 5}, - {"Medium", 500, 10}, - {"Large", 5000, 20}, - {"Small single", 50, 1}, - {"Medium single", 500, 1}, - {"Large 5", 5000, 5}, + {"Small", 50, 5, 90, 120}, + {"Medium", 500, 100, 110, 140}, + {"Large", 5000, 200, 800, 1300}, + {"Small single", 50, 10, 90, 120}, + {"Medium single", 500, 10, 110, 170}, + {"Large 5", 5000, 15, 1300, 1800}, } log.SetOutput(io.Discard) @@ -881,8 +882,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } duration := time.Since(start) - b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") - b.ReportMetric(0, "ns/op") + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } }) } } @@ -1445,8 +1454,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1457,7 +1465,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f01..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) if err != nil { - return err + return nil, err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } + return policy, nil +} + +// DeletePolicy from the store +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } -// DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - - if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave - - if !policyToSave.Enabled && !oldPolicy.Enabled { + if !policy.Enabled && !existingPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) - return updateAccountPeers, nil - } + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil -} - -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, + if hasPeers { + return true, nil } } - return result + + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) +} + +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err + } + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + if err != nil { + return err + } + + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + if err != nil { + return err + } + + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor + ruleCopy.PolicyID = policy.ID + } + + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := postureChecks[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := groups[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2..62d80f46e 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) + var policyWithGroupRulesNoPeers *Policy + var policyWithDestinationPeersOnly *Policy + var policyWithSourceAndDestinationPeers *Policy + // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-rule-groups-no-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with source group containing peers, but destination group without peers should // update account's peers and send peer update t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-has-peers-destination-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, @@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-destination-has-peers-source-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), - Enabled: false, + Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, @@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, @@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Disabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = false + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Updating disabled policy with destination and source groups containing peers should not update account's peers // or send peer update t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Description: "updated description", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Description = "updated description" + policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Enabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = true + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { - policyID := "policy-source-destination-peers" - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID) assert.NoError(t, err) select { @@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { - policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID) assert.NoError(t, err) select { @@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID) assert.NoError(t, err) select { diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf..b2f308d76 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5..0467efedb 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,237 +2,285 @@ package server import ( "context" + "errors" + "fmt" "slices" + "github.com/rs/xid" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -const ( - errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" -) - func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) - } - - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} - -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} + +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err } - exists, uniqName := am.savePostureChecks(account, postureChecks) - - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } + + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) + var postureChecks *posture.Checks + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) if err != nil { return err } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true - } - if p.Name == postureChecks.Name { - uniqName = false - } - } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) - } - return -} - -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) - } - - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) - } - - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) - - return postureChecks, nil -} - // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) if len(account.PostureChecks) == 0 { - return nil + return nil, nil } for _, policy := range account.Policies { - if !policy.Enabled { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { continue } - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err } } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err } - return postureChecksList + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } + } + + return false, nil +} + +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err + } + return nil + } + + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) + } + } + + postureChecks.ID = xid.New().String() + + return nil +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) + if err != nil { + return err + } + + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck := account.getPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") + } + + if slices.Contains(group.Peers, peerID) { + return true, nil } } } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck - } - } - } -} - -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d..93e5741cf 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +16,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +32,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +40,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +56,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +71,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +144,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +176,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +194,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -202,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := Policy{ - ID: "policyA", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -215,7 +220,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -226,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -238,7 +243,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +260,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -274,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -293,7 +297,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,17 +307,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -321,9 +323,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + SourcePostureChecks: []string{postureCheckB.ID}, + }) assert.NoError(t, err) done := make(chan struct{}) @@ -332,12 +333,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -354,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - policy = Policy{ - ID: "policyB", + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, @@ -367,10 +367,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, - } - - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + SourcePostureChecks: []string{postureCheckB.ID}, + }) assert.NoError(t, err) done := make(chan struct{}) @@ -379,12 +377,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -397,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -409,9 +406,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + SourcePostureChecks: []string{postureCheckB.ID}, + }) assert.NoError(t, err) done := make(chan struct{}) @@ -420,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +425,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +436,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, - }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, - { - ID: "checkB", - }, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, } + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + AccountID: account.Id, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{postureCheckA.ID}, + } + + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d3..ecb562645 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c..108f791e0 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { defaultRule := rules[0] newPolicy := defaultRule.Copy() - newPolicy.ID = xid.New().String() newPolicy.Name = "peer1 only" newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 7c8200706..614547c60 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a15caee0e..aecebf11b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1225,9 +1225,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "dns settings not found") + return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get dns settings from store") } return &accountDNSSettings.DNSSettings, nil } @@ -1310,8 +1311,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1362,22 +1363,139 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + First(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewPolicyNotFoundError(policyID) + } + log.WithContext(ctx).Errorf("failed to get policy from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in store") + } + + return nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. +func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. @@ -1447,12 +1565,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) + var nsGroups []*nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server groups from store") + } + + return nsGroups, nil } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. -func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + var nsGroup *nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewNameServerGroupNotFoundError(nsGroupID) + } + log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server group from store") + } + + return nsGroup, nil +} + +// SaveNameServerGroup saves a name server group to the database. +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + return nil +} + +// DeleteNameServerGroup deletes a name server group from the database. +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete name server group from store") + } + + if result.RowsAffected == 0 { + return status.NewNameServerGroupNotFoundError(nsGroupID) + } + + return nil } // getRecords retrieves records from the database based on the account ID. @@ -1487,3 +1648,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// SaveDNSSettings saves the DNS settings to the store. +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save dns settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6..6064b019f 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,489 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, postureChecks) + } else { + require.NoError(t, err) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) + } + }) + } +} + +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} + +func TestSqlStore_GetDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account dns settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account dns settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve dns settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, dnsSettings) + } else { + require.NoError(t, err) + require.NotNil(t, dnsSettings) + } + }) + } +} + +func TestSqlStore_SaveDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} + err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + require.NoError(t, err) + + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, saveDNSSettings, dnsSettings) +} + +func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve name server groups by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetNameServerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + nsGroupID string + expectError bool + }{ + { + name: "retrieve existing nameserver group", + nsGroupID: "csqdelq7qv97ncu7d9t0", + expectError: false, + }, + { + name: "retrieve non-existing nameserver group", + nsGroupID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty nameserver group ID", + nsGroupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, nsGroup) + } else { + require.NoError(t, err) + require.NotNil(t, nsGroup) + require.Equal(t, tt.nsGroupID, nsGroup.ID) + } + }) + } +} + +func TestSqlStore_SaveNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + nsGroup := &nbdns.NameServerGroup{ + ID: "ns-group-id", + AccountID: accountID, + Name: "NS Group", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: 1, + Port: 53, + }, + }, + Groups: []string{"groupA"}, + Primary: true, + Enabled: true, + SearchDomainsEnabled: false, + } + + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup) + require.NoError(t, err) + + saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID) + require.NoError(t, err) + require.Equal(t, saveNSGroup, nsGroup) +} + +func TestSqlStore_DeleteNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nsGroupID := "csqdelq7qv97ncu7d9t0" + + err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.NoError(t, err) + + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.Error(t, err) + require.Nil(t, nsGroup) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 8b6d0077b..59f436f5b 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,18 @@ func NewGetAccountError(err error) error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} + +// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group +func NewNameServerGroupNotFoundError(nsGroupID string) error { + return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) +} diff --git a/management/server/store.go b/management/server/store.go index 937270527..59d8e86fa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) @@ -80,11 +81,17 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error @@ -110,6 +117,8 @@ type Store interface { GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7..455111439 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,7 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); +INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..498017afa 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) require.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 000000000..26693f841 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go b/util/net/dial.go new file mode 100644 index 000000000..595311492 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_ios.go b/util/net/dial_ios.go similarity index 100% rename from util/net/dialer_ios.go rename to util/net/dial_ios.go diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go deleted file mode 100644 index 4cbded536..000000000 --- a/util/net/dialer_android.go +++ /dev/null @@ -1,25 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_dial.go similarity index 63% rename from util/net/dialer_nonios.go rename to util/net/dialer_dial.go index 34004a368..1659b6220 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_dial.go @@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { @@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r return result.ErrorOrNil() } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_init_android.go b/util/net/dialer_init_android.go new file mode 100644 index 000000000..63b903348 --- /dev/null +++ b/util/net/dialer_init_android.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = ControlProtectSocket +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_init_linux.go similarity index 88% rename from util/net/dialer_linux.go rename to util/net/dialer_init_linux.go index aed5c59a3..d801e6080 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_init_linux.go @@ -7,6 +7,6 @@ import "syscall" // init configures the net.Dialer Control function to set the fwmark on the socket func (d *Dialer) init() { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_init_nonlinux.go similarity index 58% rename from util/net/dialer_nonlinux.go rename to util/net/dialer_init_nonlinux.go index c838441bd..8c57ebbaa 100644 --- a/util/net/dialer_nonlinux.go +++ b/util/net/dialer_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (d *Dialer) init() { + // implemented on Linux and Android only } diff --git a/util/net/env.go b/util/net/env.go new file mode 100644 index 000000000..099da39b7 --- /dev/null +++ b/util/net/env.go @@ -0,0 +1,29 @@ +package net + +import ( + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" +) + +func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } + return os.Getenv(envDisableCustomRouting) == "true" +} + +func SkipSocketMark() bool { + if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { + log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) + return true + } + return false +} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 000000000..3ae8a9435 --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_ios.go b/util/net/listen_ios.go similarity index 100% rename from util/net/listener_ios.go rename to util/net/listen_ios.go diff --git a/util/net/listener_android.go b/util/net/listener_android.go deleted file mode 100644 index d4167ad53..000000000 --- a/util/net/listener_android.go +++ /dev/null @@ -1,26 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect listener socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/listener_init_android.go b/util/net/listener_init_android.go new file mode 100644 index 000000000..f7bfa1dab --- /dev/null +++ b/util/net/listener_init_android.go @@ -0,0 +1,6 @@ +package net + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = ControlProtectSocket +} diff --git a/util/net/listener_linux.go b/util/net/listener_init_linux.go similarity index 89% rename from util/net/listener_linux.go rename to util/net/listener_init_linux.go index 8d332160a..e32d5d894 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_init_linux.go @@ -9,6 +9,6 @@ import ( // init configures the net.ListenerConfig Control function to set the fwmark on the socket func (l *ListenerConfig) init() { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_init_nonlinux.go similarity index 61% rename from util/net/listener_nonlinux.go rename to util/net/listener_init_nonlinux.go index 14a6be49d..80f6f7f1a 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (l *ListenerConfig) init() { + // implemented on Linux and Android only } diff --git a/util/net/listener_nonios.go b/util/net/listener_listen.go similarity index 84% rename from util/net/listener_nonios.go rename to util/net/listener_listen.go index ae4be3494..efffba40e 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_listen.go @@ -8,7 +8,6 @@ import ( "net" "sync" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { return err } - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/net.go b/util/net/net.go index 5448eb85a..403aa87e7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -2,9 +2,6 @@ package net import ( "net" - "os" - - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) @@ -16,8 +13,6 @@ const ( PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasqueradeReturn = 0x1BD12 - - envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } - -func CustomRoutingDisabled() bool { - if netstack.IsEnabled() { - return true - } - return os.Getenv(envDisableCustomRouting) == "true" -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 98f49af8d..fc486ebd4 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,29 +4,42 @@ package net import ( "fmt" - "os" "syscall" log "github.com/sirupsen/logrus" ) -const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" - // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { + if isSocketMarkDisabled() { + return nil + } + sysconn, err := conn.SyscallConn() if err != nil { return fmt.Errorf("get raw conn: %w", err) } - return SetRawSocketMark(sysconn) + return setRawSocketMark(sysconn) } -func SetRawSocketMark(conn syscall.RawConn) error { +// SetSocketOpt sets the SO_MARK option on the given file descriptor +func SetSocketOpt(fd int) error { + if isSocketMarkDisabled() { + return nil + } + + return setSocketOptInt(fd) +} + +func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = SetSocketOpt(int(fd)) + if isSocketMarkDisabled() { + return + } + setErr = setSocketOptInt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } -func SetSocketOpt(fd int) error { - if CustomRoutingDisabled() { - log.Infof("Custom routing is disabled, skipping SO_MARK") - return nil - } - - // Check for the new environment variable - if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { - log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") - return nil - } - +func setSocketOptInt(fd int) error { return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) } + +func isSocketMarkDisabled() bool { + if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return true + } + + if SkipSocketMark() { + return true + } + return false +} diff --git a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go index 64fb45aa4..febed8a1e 100644 --- a/util/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -1,14 +1,42 @@ package net -import "sync" +import ( + "fmt" + "sync" + "syscall" +) var ( androidProtectSocketLock sync.Mutex androidProtectSocket func(fd int32) bool ) -func SetAndroidProtectSocketFn(f func(fd int32) bool) { +func SetAndroidProtectSocketFn(fn func(fd int32) bool) { androidProtectSocketLock.Lock() - androidProtectSocket = f + androidProtectSocket = fn androidProtectSocketLock.Unlock() } + +// ControlProtectSocket is a Control function that sets the fwmark on the socket +func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + var aErr error + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + defer androidProtectSocketLock.Unlock() + + if androidProtectSocket == nil { + aErr = fmt.Errorf("socket protection function not set") + return + } + + if !androidProtectSocket(int32(fd)) { + aErr = fmt.Errorf("failed to protect socket via Android") + } + }) + + if err != nil { + return err + } + + return aErr +}