Compare commits

...

19 Commits

Author SHA1 Message Date
Maycon Santos
75a02c3e38 re-add wmi 2024-11-29 00:34:43 +01:00
Maycon Santos
052525093a add missing static info 2024-11-29 00:09:54 +01:00
Maycon Santos
b4b9aedf5a Add static system info to other OSs 2024-11-29 00:04:36 +01:00
Maycon Santos
07f0f9fdbd Add static system info 2024-11-28 23:46:29 +01:00
Pascal Fischer
c6641be94b [tests] Enable benchmark tests on github actions (#2961) 2024-11-28 19:22:01 +01:00
Pascal Fischer
89cf8a55e2 [management] Add performance test for login and sync calls (#2960) 2024-11-28 14:59:53 +01:00
Pascal Fischer
00c3b67182 [management] refactor to use account object instead of separate db calls for peer update (#2957) 2024-11-28 11:13:01 +01:00
Zoltan Papp
9203690033 [client] Code cleaning in net pkg and fix exit node feature on Android(#2932)
Code cleaning around the util/net package. The goal was to write a more understandable source code but modify nothing on the logic.
Protect the WireGuard UDP listeners with marks.
The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features.

- uniformize the file name convention: [struct_name] _ [functions] _ [os].go
- code cleaning in net_linux.go
- move env variables to env.go file
2024-11-26 23:34:27 +01:00
Bethuel Mmbaga
9683da54b0 [management] Refactor nameserver groups to use store methods (#2888) 2024-11-26 17:39:04 +01:00
Bethuel Mmbaga
0e48a772ff [management] Refactor DNS settings to use store methods (#2883) 2024-11-26 13:43:05 +01:00
Bethuel Mmbaga
f118d81d32 [management] Refactor policy to use store methods (#2878) 2024-11-26 10:46:05 +01:00
Bethuel Mmbaga
ca12bc6953 [management] Refactor posture check to use store methods (#2874) 2024-11-25 16:26:24 +01:00
Viktor Liu
9810386937 [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) 2024-11-25 15:19:56 +01:00
Viktor Liu
f1625b32bd [client] Set up sysctl and routing table name only if routing rules are available (#2933) 2024-11-25 15:12:16 +01:00
Viktor Liu
0ecd5f2118 [client] Test nftables for incompatible iptables rules (#2948) 2024-11-25 15:11:56 +01:00
Viktor Liu
940d0c48c6 [client] Don't return error in userspace mode without firewall (#2924) 2024-11-25 15:11:31 +01:00
Maycon Santos
56cecf849e Import time package (#2940) 2024-11-22 20:40:30 +01:00
Maycon Santos
05c4aa7c2c [misc] Renew slack link (#2938) 2024-11-22 18:50:47 +01:00
Zoltan Papp
2a5cb16494 [relay] Refactor initial Relay connection (#2800)
Can support firewalls with restricted WS rules

allow to run engine without Relay servers
keep up to date Relay address changes
2024-11-22 18:12:34 +01:00
82 changed files with 2746 additions and 1153 deletions

View File

@@ -52,6 +52,47 @@ jobs:
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
benchmark:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:

View File

@@ -17,7 +17,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
@@ -34,7 +34,7 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
<br/>
</strong>

View File

@@ -106,7 +106,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}
@@ -132,7 +132,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
}

View File

@@ -131,7 +131,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
return
})
if err != nil {
@@ -179,7 +179,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

View File

@@ -137,9 +137,11 @@ var loginCmd = &cobra.Command{
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
needsLogin := false
staticInfoChan := system.GetStaticInfoInBackground(ctx)
staticInfo := <-staticInfoChan
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
err := internal.Login(ctx, config, "", "", staticInfo)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
@@ -162,7 +164,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
err := internal.Login(ctx, config, setupKey, jwtToken, staticInfo)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil

View File

@@ -152,6 +152,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err
}
staticInfoChan := system.GetStaticInfoInBackground(ctx)
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@@ -171,7 +173,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, <-staticInfoChan)
return connectClient.Run()
}

View File

@@ -1,9 +1,11 @@
package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
"time"
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
})
}
}
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
return nil
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}

View File

@@ -0,0 +1,12 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -40,19 +40,21 @@ type ConnectClient struct {
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
staticInfo *system.StaticInfo
}
func NewConnectClient(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
staticInfo *system.StaticInfo,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
staticInfo: staticInfo,
}
}
@@ -179,7 +181,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.staticInfo)
if err != nil {
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
@@ -232,6 +234,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
@@ -242,9 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig()
@@ -258,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks, c.staticInfo)
c.engineMutex.Unlock()
@@ -425,14 +426,14 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
}
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo := system.GetInfo(ctx, staticInfo)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
if err != nil {
return nil, err

View File

@@ -172,6 +172,8 @@ type Engine struct {
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
staticInfo *system.StaticInfo
}
// Peer is an instance of the Connection Peer
@@ -180,8 +182,8 @@ type Peer struct {
WgAllowedIps string
}
// NewEngine creates a new Connection Engine
func NewEngine(
// newEngine creates a new Connection Engine
func newEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
@@ -203,6 +205,7 @@ func NewEngine(
statusRecorder,
nil,
checks,
nil,
)
}
@@ -218,6 +221,7 @@ func NewEngineWithProbes(
statusRecorder *peer.Status,
probes *ProbeHolder,
checks []*mgmProto.Checks,
staticInfo *system.StaticInfo,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
@@ -237,6 +241,7 @@ func NewEngineWithProbes(
statusRecorder: statusRecorder,
probes: probes,
checks: checks,
staticInfo: staticInfo,
}
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
@@ -538,6 +543,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
// when we receive token we expect valid address list too
c := &auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
@@ -546,9 +552,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(relayMsg.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
}
// todo update relay address in the relay manager
// todo update signal
}
@@ -574,10 +587,10 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
info, err := system.GetInfoWithChecks(e.ctx, checks, e.staticInfo)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
info = system.GetInfo(e.ctx, e.staticInfo)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -677,10 +690,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.staticInfo)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
info = system.GetInfo(e.ctx, e.staticInfo)
}
// err = e.mgmClient.Sync(info, e.handleSync)
@@ -1184,7 +1197,7 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
info := system.GetInfo(e.ctx, e.staticInfo)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err

View File

@@ -84,7 +84,7 @@ func TestEngine_SSH(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(
engine := newEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
@@ -229,7 +229,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(
engine := newEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
@@ -434,7 +434,7 @@ func TestEngine_Sync(t *testing.T) {
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
@@ -594,7 +594,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -1118,7 +1118,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
info := system.GetInfo(ctx)
info := system.GetInfo(ctx, nil)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
if err != nil {
return nil, err
@@ -1140,7 +1140,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := newEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -17,7 +17,7 @@ import (
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string, staticInfo *system.StaticInfo) (bool, error) {
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
if err != nil {
return false, err
@@ -38,7 +38,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
return false, err
}
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
if isLoginNeeded(err) {
return true, nil
}
@@ -46,7 +46,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
}
// Login or register the client
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string, staticInfo *system.StaticInfo) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
@@ -67,10 +67,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
return err
}
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, staticInfo)
return err
}
@@ -99,28 +99,28 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, staticInfo *system.StaticInfo) (*wgtypes.Key, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo := system.GetInfo(ctx, staticInfo)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
return serverKey, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info := system.GetInfo(ctx, staticInfo)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
if err != nil {
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())

View File

@@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
})
}
return relayStates
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
Err: err,
})
}
relayState.Err = err
return relayStates
}
relayState.URI = instanceAddr
relayState := relay.ProbeResult{
URI: instanceAddr,
}
return append(relayStates, relayState)
}

View File

@@ -46,8 +46,6 @@ type WorkerICE struct {
hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent
muxAgent sync.Mutex
@@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = icemaker.CandidateTypes()
}
@@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci)
go w.conn.OnConnReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
}
return false
}
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
if isRelayed(pair) {
return connPriorityICETurn
} else {
return connPriorityICEP2P
}
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
}
// setIsLegacy sets the legacy routing setup
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
return r.setupRefCounter(initAddresses, stateManager)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
}
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
return nil, nil, nil
}
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
return fmt.Errorf("remove routing rule: %w", err)
}

View File

@@ -123,7 +123,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
}
@@ -204,7 +204,7 @@ func (c *Client) IsLoginRequired() bool {
ConfigPath: c.cfgFile,
})
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey, &system.StaticInfo{})
return needsLogin
}
@@ -244,7 +244,7 @@ func (c *Client) LoginForMobile() string {
return
}
jwtToken := tokenInfo.GetTokenToUse()
_ = internal.Login(ctx, cfg, "", jwtToken)
_ = internal.Login(ctx, cfg, "", jwtToken, &system.StaticInfo{})
c.loginComplete = true
}()

View File

@@ -104,7 +104,7 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
@@ -123,7 +123,7 @@ func (a *Auth) Login() error {
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
return
})
if err != nil {
@@ -136,7 +136,7 @@ func (a *Auth) Login() error {
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

View File

@@ -68,6 +68,7 @@ type Server struct {
relayProbe *internal.Probe
wgProbe *internal.Probe
lastProbe time.Time
staticInfo *system.StaticInfo
}
type oauthAuthFlow struct {
@@ -79,6 +80,8 @@ type oauthAuthFlow struct {
// New server instance constructor.
func New(ctx context.Context, configPath, logFile string) *Server {
staticInfoChan := system.GetStaticInfoInBackground(ctx)
staticInfo := <-staticInfoChan
return &Server{
rootCtx: ctx,
latestConfigInput: internal.ConfigInput{
@@ -89,6 +92,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
signalProbe: internal.NewProbe(),
relayProbe: internal.NewProbe(),
wgProbe: internal.NewProbe(),
staticInfo: staticInfo,
}
}
@@ -195,7 +199,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.staticInfo)
probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe,
@@ -272,7 +276,7 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
err := internal.Login(ctx, s.config, setupKey, jwtToken)
err := internal.Login(ctx, s.config, setupKey, jwtToken, s.staticInfo)
if err != nil {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Warnf("failed login: %v", err)

View File

@@ -61,6 +61,14 @@ type Info struct {
Files []File // for posture checks
}
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -142,7 +150,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, staticInfo *StaticInfo) (*Info, error) {
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
@@ -153,8 +161,17 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
return nil, err
}
info := GetInfo(ctx)
info := GetInfo(ctx, staticInfo)
info.Files = files
return info, nil
}
// GetStaticInfoInBackground retrieves and parses the system information in the background
func GetStaticInfoInBackground(ctx context.Context) <-chan *StaticInfo {
ch := make(chan *StaticInfo)
go func() {
ch <- getStaticInfo(ctx)
}()
return ch
}

View File

@@ -16,7 +16,7 @@ import (
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
kernel := "android"
osInfo := uname()
if len(osInfo) == 2 {
@@ -44,6 +44,10 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func getStaticInfo(ctx context.Context) *StaticInfo {
return nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -21,7 +21,7 @@ import (
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
utsname := unix.Utsname{}
err := unix.Uname(&utsname)
if err != nil {
@@ -41,26 +41,22 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
serialNum, prodName, manufacturer := sysInfo()
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
gio := &Info{
Kernel: sysName,
OSVersion: strings.TrimSpace(string(swVersion)),
Platform: machine,
OS: sysName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: release,
NetworkAddresses: addrs,
}
gio := &Info{
Kernel: sysName,
OSVersion: strings.TrimSpace(string(swVersion)),
Platform: machine,
OS: sysName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: release,
NetworkAddresses: addrs,
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
if staticInfo != nil {
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
gio.SystemProductName = staticInfo.SystemProductName
gio.SystemManufacturer = staticInfo.SystemManufacturer
gio.Environment = staticInfo.Environment
}
systemHostname, _ := os.Hostname()
@@ -71,6 +67,21 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func getStaticInfo(ctx context.Context) *StaticInfo {
serialNum, prodName, manufacturer := sysInfo()
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
return &StaticInfo{
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
out, _ := exec.Command("/usr/sbin/ioreg", "-l").Output() // err ignored for brevity
for _, l := range strings.Split(string(out), "\n") {

View File

@@ -19,7 +19,7 @@ import (
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
out := _getInfo()
for strings.Contains(out, "broken pipe") {
out = _getInfo()
@@ -29,16 +29,11 @@ func GetInfo(ctx context.Context) *Info {
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
osName, osVersion := readOsReleaseFile()
systemHostname, _ := os.Hostname()
return &Info{
info := &Info{
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
@@ -49,7 +44,25 @@ func GetInfo(ctx context.Context) *Info {
WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
}
if staticInfo != nil {
info.SystemSerialNumber = staticInfo.SystemSerialNumber
info.SystemProductName = staticInfo.SystemProductName
info.SystemManufacturer = staticInfo.SystemManufacturer
info.Environment = staticInfo.Environment
}
return info
}
func getStaticInfo(ctx context.Context) *StaticInfo {
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
return &StaticInfo{
Environment: env,
}
}

View File

@@ -11,7 +11,7 @@ import (
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
// Convert fixed-size byte arrays to Go strings
sysName := extractOsName(ctx, "sysName")
@@ -25,6 +25,10 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func getStaticInfo(ctx context.Context) *StaticInfo {
return nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -42,7 +42,7 @@ func (s SysInfoWrapper) GetSysInfo() SysInfo {
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
info := _getInfo()
for strings.Contains(info, "broken pipe") {
info = _getInfo()
@@ -65,14 +65,6 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
gio := &Info{
Kernel: osInfo[0],
Platform: osInfo[2],
@@ -85,13 +77,32 @@ func GetInfo(ctx context.Context) *Info {
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
}
if staticInfo != nil {
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
gio.SystemProductName = staticInfo.SystemProductName
gio.SystemManufacturer = staticInfo.SystemManufacturer
gio.Environment = staticInfo.Environment
}
return gio
}
func getStaticInfo(ctx context.Context) *StaticInfo {
si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
return &StaticInfo{
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
return gio
}
func _getInfo() string {

View File

@@ -9,7 +9,7 @@ import (
)
func Test_LocalWTVersion(t *testing.T) {
got := GetInfo(context.TODO())
got := GetInfo(context.TODO(), nil)
want := "development"
assert.Equal(t, want, got.WiretrusteeVersion)
}
@@ -21,7 +21,7 @@ func Test_UIVersion(t *testing.T) {
"user-agent": {want},
})
got := GetInfo(ctx)
got := GetInfo(ctx, nil)
assert.Equal(t, want, got.UIVersion)
}
@@ -30,7 +30,7 @@ func Test_CustomHostname(t *testing.T) {
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
want := "custom-host"
got := GetInfo(ctx)
got := GetInfo(ctx, nil)
assert.Equal(t, want, got.Hostname)
}

View File

@@ -33,7 +33,7 @@ type Win32_BIOS struct {
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
osName, osVersion := getOSNameAndVersion()
buildVersion := getBuildVersion()
@@ -42,39 +42,22 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
serialNum, err := sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
prodName, err := sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err := sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
gio := &Info{
Kernel: "windows",
OSVersion: osVersion,
Platform: "unknown",
OS: osName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: buildVersion,
NetworkAddresses: addrs,
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
Kernel: "windows",
OSVersion: osVersion,
Platform: "unknown",
OS: osName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: buildVersion,
NetworkAddresses: addrs,
}
if staticInfo != nil {
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
gio.SystemProductName = staticInfo.SystemProductName
gio.SystemManufacturer = staticInfo.SystemManufacturer
gio.Environment = staticInfo.Environment
}
systemHostname, _ := os.Hostname()
@@ -85,6 +68,41 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func getStaticInfo(ctx context.Context) *StaticInfo {
serialNum, prodName, manufacturer := sysInfo()
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
return &StaticInfo{
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")

2
go.mod
View File

@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

4
go.sum
View File

@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -174,7 +174,7 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
sysInfo := system.GetInfo(context.TODO(), nil)
_, err = client.Login(*key, sysInfo, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
@@ -202,7 +202,7 @@ func TestClient_LoginRegistered(t *testing.T) {
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
info := system.GetInfo(context.TODO(), nil)
resp, err := client.Register(*key, ValidKey, "", info, nil)
if err != nil {
t.Error(err)
@@ -232,7 +232,7 @@ func TestClient_Sync(t *testing.T) {
t.Error(err)
}
info := system.GetInfo(context.TODO())
info := system.GetInfo(context.TODO(), nil)
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
if err != nil {
t.Error(err)
@@ -248,7 +248,7 @@ func TestClient_Sync(t *testing.T) {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
info = system.GetInfo(context.TODO(), nil)
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
if err != nil {
t.Fatal(err)
@@ -346,7 +346,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}, nil
}
info := system.GetInfo(context.TODO())
info := system.GetInfo(context.TODO(), nil)
_, err = testClient.Register(*key, ValidKey, "", info, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)

View File

@@ -113,7 +113,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager

View File

@@ -6,13 +6,17 @@ import (
b64 "encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"os"
"reflect"
"strconv"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
}
b.Run("public without account ID", func(b *testing.B) {
//b.ResetTimer()
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
@@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
})
b.Run("private without account ID", func(b *testing.B) {
//b.ResetTimer()
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
@@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id
//b.ResetTimer()
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
@@ -1238,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return
}
policy := Policy{
ID: "policy",
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1250,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1320,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1345,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1366,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
policy := Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1377,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
})
if err != nil {
t.Errorf("save policy: %v", err)
return
}
@@ -1421,7 +1421,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
require.NoError(t, err, "failed to save group")
policy := Policy{
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1432,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
})
if err != nil {
t.Errorf("save policy: %v", err)
return
}
@@ -2987,3 +2986,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
t.Error("Timed out waiting for update message")
}
}
func BenchmarkSyncAndMarkPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 3, 4, 10},
{"Medium", 500, 100, 7, 13, 10, 60},
{"Large", 5000, 200, 65, 80, 60, 170},
{"Small single", 50, 10, 1, 3, 4, 60},
{"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 170},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 102, 110, 102, 120},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 160, 200, 160, 270},
{"Small single", 50, 10, 102, 110, 102, 120},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 160, 200, 160, 270},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 107, 120, 107, 140},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 180, 220, 180, 320},
{"Small single", 50, 10, 107, 120, 105, 140},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 180, 220, 180, 320},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"slices"
"strconv"
"sync"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}
for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})
}
for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
return validateGroups(settings.DisabledManagementGroups, groups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{

View File

@@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
@@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() {
return true, nil
}
}
return false, nil
}

View File

@@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}, false)
})
assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update

View File

@@ -6,10 +6,8 @@ import (
"strconv"
"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
isUpdate := policyID != ""
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
policy := &server.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
ID: ruleID,
PolicyID: policyID,
Name: rule.Name,
Destinations: rule.Destinations,
Sources: rule.Sources,
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks
}
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
resp := toPolicyResponse(allGroups, &policy)
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return

View File

@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
return policy, nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil

View File

@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
return nil
return postureChecks, nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]

View File

@@ -49,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
@@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
return am.SavePolicyFunc(ctx, accountID, userID, policy)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
@@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface

View File

@@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Description: description,
NameServers: nameServerList,
@@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
err = validateNameServerGroup(false, newNSGroup, account)
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
})
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil
}
@@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
})
if err != nil {
return err
}
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
nsGroup := account.NameServerGroups[nsGroupID]
if nsGroup == nil {
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err
}
if anyGroupHasPeers(account, nsGroup.Groups) {
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil
}
@@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
return nil
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
@@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
}
}
@@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 3 {
nsListLength := len(list)
if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
}
return nil
@@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
if _, found := groups[id]; !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
}
@@ -277,11 +340,3 @@ func validateDomain(domain string) error {
return nil
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, newPeer)
postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
@@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
postureChecks = am.getPeerPostureChecks(account, peer)
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})

View File

@@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
var (
group1 nbgroup.Group
group2 nbgroup.Group
policy Policy
)
group1.ID = xid.New().String()
group2.ID = xid.New().String()
group1.Name = "src"
group2.Name = "dst"
policy.ID = xid.New().String()
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
@@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
policy.Name = "test"
policy.Enabled = true
policy.Rules = []*PolicyRule{
{
Enabled: true,
Sources: []string{group1.ID},
Destinations: []string{group2.ID},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
policy := &Policy{
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{group1.ID},
Destinations: []string{group2.ID},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -833,19 +833,23 @@ func BenchmarkGetPeers(b *testing.B) {
})
}
}
func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
{"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 140, 120, 200},
{"Large", 5000, 200, 800, 1300, 2500, 3600},
{"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 1800, 5000, 6000},
}
log.SetOutput(io.Discard)
@@ -881,8 +885,23 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
@@ -1445,8 +1464,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1457,7 +1475,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}, false)
})
require.NoError(t, err)
done := make(chan struct{})

View File

@@ -3,13 +3,13 @@ package server
import (
"context"
_ "embed"
"slices"
"strconv"
"strings"
"github.com/netbirdio/netbird/management/proto"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -125,6 +125,7 @@ type PolicyRule struct {
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
@@ -171,6 +172,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
@@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
return nil, err
}
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var action = activity.PolicyAdded
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
saveFunc := transaction.CreatePolicy
if isUpdate {
action = activity.PolicyUpdated
saveFunc = transaction.SavePolicy
}
return saveFunc(ctx, LockingStrengthUpdate, policy)
})
if err != nil {
return err
return nil, err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
action := activity.PolicyAdded
if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return policy, nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var policy *Policy
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
if err != nil {
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
policy, err := am.deletePolicy(account, policyID)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// ListPolicies from the store
// ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if err != nil {
return false, err
}
oldPolicy := account.Policies[policyIdx]
// Update the existing policy
account.Policies[policyIdx] = policyToSave
if !policyToSave.Enabled && !oldPolicy.Enabled {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
if err != nil {
return false, err
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
if hasPeers {
return true, nil
}
}
return result
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
}
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
if err != nil {
return err
}
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}
for i, rule := range policy.Rules {
ruleCopy := rule.Copy()
if ruleCopy.ID == "" {
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
ruleCopy.PolicyID = policy.ID
}
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
policy.Rules[i] = ruleCopy
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
return nil
}
// getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil
}
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
validIDs := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
if _, exists := postureChecks[id]; exists {
validIDs = append(validIDs, id)
}
}
return result
return validIDs
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
validIDs := make([]string, 0, len(groupIDs))
for _, id := range groupIDs {
if _, exists := groups[id]; exists {
validIDs = append(validIDs, id)
}
}
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
}
return result

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
@@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
var policyWithGroupRulesNoPeers *Policy
var policyWithDestinationPeersOnly *Policy
var policyWithSourceAndDestinationPeers *Policy
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{
ID: "policy-rule-groups-no-peers",
Enabled: true,
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
assert.NoError(t, err)
select {
@@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with source group containing peers, but destination group without peers should
// update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{
ID: "policy-source-has-peers-destination-none",
Enabled: true,
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
@@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
assert.NoError(t, err)
select {
@@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{
ID: "policy-destination-has-peers-source-none",
Enabled: true,
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: false,
Enabled: true,
Sources: []string{"groupC"},
Destinations: []string{"groupD"},
Bidirectional: true,
@@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
assert.NoError(t, err)
select {
@@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
@@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
assert.NoError(t, err)
select {
@@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Disabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
policyWithSourceAndDestinationPeers.Enabled = false
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Description: "updated description",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
policyWithSourceAndDestinationPeers.Description = "updated description"
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Enabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
policyWithSourceAndDestinationPeers.Enabled = true
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
policyID := "policy-source-destination-peers"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
assert.NoError(t, err)
select {
@@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
assert.NoError(t, err)
select {
@@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
policyID := "policy-rule-groups-no-peers"
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
assert.NoError(t, err)
select {

View File

@@ -7,8 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
}
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := Checks{
ID: postureChecksID,
Name: name,

View File

@@ -2,237 +2,285 @@ package server
import (
"context"
"errors"
"fmt"
"slices"
"github.com/rs/xid"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return nil, status.NewAdminPermissionError()
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
exists, uniqName := am.savePostureChecks(account, postureChecks)
// we do not allow create new posture checks with non uniq name
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
action := activity.PostureCheckCreated
if exists {
action = activity.PostureCheckUpdated
account.Network.IncSerial()
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
return err
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
})
if err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
return postureChecks, nil
}
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return status.NewAdminPermissionError()
}
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
})
if err != nil {
return err
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
return nil
}
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
uniqName = true
for i, p := range account.PostureChecks {
if !exists && p.ID == postureChecks.ID {
account.PostureChecks[i] = postureChecks
exists = true
}
if p.Name == postureChecks.Name {
uniqName = false
}
}
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
peerPostureChecks := make(map[string]posture.Checks)
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
addPolicyPostureChecks(account, policy, peerPostureChecks)
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy)
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
return postureChecksList
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for _, check := range checks {
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
postureChecks.ID = xid.New().String()
return nil
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.getPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if ok && slices.Contains(group.Peers, peerID) {
return true
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil
}
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
if !exists {
return false
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
if !isLinked {
return false
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
}
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
return nil
}

View File

@@ -5,8 +5,8 @@ import (
"testing"
"time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
@@ -16,7 +16,6 @@ import (
const (
adminUserID = "adminUserID"
regularUserID = "regularUserID"
postureCheckID = "existing-id"
postureCheckName = "Existing check"
)
@@ -33,7 +32,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
_, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
assert.Error(t, err)
// regular users cannot list check
@@ -41,8 +40,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@@ -58,8 +56,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: "new-id",
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
@@ -74,23 +71,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
})
}
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
assert.NoError(t, err)
// users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
assert.Error(t, err)
// admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
assert.NoError(t, err)
@@ -150,9 +144,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
postureCheck := posture.Checks{
ID: "postureCheck",
Name: "postureCheck",
postureCheckA := &posture.Checks{
Name: "postureCheckA",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
},
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
require.NoError(t, err)
postureCheckB := &posture.Checks{
Name: "postureCheckB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@@ -169,7 +176,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -187,12 +194,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -202,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
policy := Policy{
ID: "policyA",
policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -215,7 +220,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
SourcePostureChecks: []string{postureCheckB.ID},
}
// Linking posture check to policy should trigger update account peers and send peer update
@@ -226,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err)
select {
@@ -238,7 +243,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) {
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
@@ -255,7 +260,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -274,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}()
policy.SourcePostureChecks = []string{}
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err)
select {
@@ -293,7 +297,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
assert.NoError(t, err)
select {
@@ -303,17 +307,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{
ID: "policyB",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -321,9 +323,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
SourcePostureChecks: []string{postureCheckB.ID},
})
assert.NoError(t, err)
done := make(chan struct{})
@@ -332,12 +333,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -354,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
policy = Policy{
ID: "policyB",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupA"},
@@ -367,10 +367,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
SourcePostureChecks: []string{postureCheckB.ID},
})
assert.NoError(t, err)
done := make(chan struct{})
@@ -379,12 +377,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -397,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{
ID: "policyB",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -409,9 +406,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
SourcePostureChecks: []string{postureCheckB.ID},
})
assert.NoError(t, err)
done := make(chan struct{})
@@ -420,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
@@ -429,7 +425,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -440,80 +436,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
})
}
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{
Policies: []*Policy{
{
ID: "policyA",
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestPostureChecksAccount(manager)
require.NoError(t, err, "failed to init testing account")
groupA := &group.Group{
ID: "groupA",
AccountID: account.Id,
Peers: []string{"peer1"},
}
groupB := &group.Group{
ID: "groupB",
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{
Name: "checkA",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
Name: "checkB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
AccountID: account.Id,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{postureCheckA.ID},
}
policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupB"}
policy.Rules[0].Destinations = []string{"groupA"}
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupA"}
policy.Rules[0].Destinations = []string{"groupB"}
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
policy.Rules[0].Sources = []string{"nonExistentGroup"}
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
}

View File

@@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
if isRouteChangeAffectPeers(account, &newRoute) {
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, accountID)
}
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, accountID)
}
@@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
if am.isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, accountID)
}
@@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}

View File

@@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
defaultRule := rules[0]
newPolicy := defaultRule.Copy()
newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
_, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)

View File

@@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
policy := Policy{
ID: "policy",
policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
}
return &accountDNSSettings.DNSSettings, nil
}
@@ -1243,8 +1244,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
}
groupsMap := make(map[string]*nbgroup.Group)
@@ -1295,22 +1296,139 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
var policies []*Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get policies from store")
}
return policies, nil
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
var policy *Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
First(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewPolicyNotFoundError(policyID)
}
log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get policy from store")
}
return policy, nil
}
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create policy in store")
}
return nil
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store")
}
return nil
}
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
return status.Errorf(status.Internal, "failed to delete policy from store")
}
if result.RowsAffected == 0 {
return status.NewPolicyNotFoundError(policyID)
}
return nil
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
}
return postureChecks, nil
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
}
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
}
return postureCheck, nil
}
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
}
postureChecksMap := make(map[string]*posture.Checks)
for _, postureCheck := range postureChecks {
postureChecksMap[postureCheck.ID] = postureCheck
}
return postureChecksMap, nil
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save posture checks to store")
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete posture checks from store")
}
if result.RowsAffected == 0 {
return status.NewPostureChecksNotFoundError(postureChecksID)
}
return nil
}
// GetAccountRoutes retrieves network routes for an account.
@@ -1380,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}
return nsGroups, nil
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
}
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
}
return nsGroup, nil
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}
if result.RowsAffected == 0 {
return status.NewNameServerGroupNotFoundError(nsGroupID)
}
return nil
}
// getRecords retrieves records from the database based on the account ID.
@@ -1420,3 +1581,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save dns settings to store")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/uuid"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -1564,3 +1565,489 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
})
}
}
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureChecksID string
expectError bool
}{
{
name: "retrieve existing posture checks",
postureChecksID: "csplshq7qv948l48f7t0",
expectError: false,
},
{
name: "retrieve non-existing posture checks",
postureChecksID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty posture checks ID",
postureChecksID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, postureChecks)
} else {
require.NoError(t, err)
require.NotNil(t, postureChecks)
require.Equal(t, tt.postureChecksID, postureChecks.ID)
}
})
}
}
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureCheckIDs []string
expectedCount int
}{
{
name: "retrieve existing posture checks by existing IDs",
postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"},
expectedCount: 2,
},
{
name: "empty posture check IDs list",
postureCheckIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing posture check IDs",
postureCheckIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing posture check IDs",
postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SavePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
postureChecks := &posture.Checks{
ID: "posture-checks-id",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.31.0",
},
OSVersionCheck: &posture.OSVersionCheck{
Ios: &posture.MinVersionCheck{
MinVersion: "13.0.1",
},
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "5.3.3-dev",
},
},
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
Action: posture.CheckActionAllow,
},
},
}
err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
require.NoError(t, err)
savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
require.NoError(t, err)
require.Equal(t, savePostureChecks, postureChecks)
}
func TestSqlStore_DeletePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
postureChecksID string
expectError bool
}{
{
name: "delete existing posture checks",
postureChecksID: "csplshq7qv948l48f7t0",
expectError: false,
},
{
name: "delete non-existing posture checks",
postureChecksID: "non-existing-posture-checks-id",
expectError: true,
},
{
name: "delete with empty posture checks ID",
postureChecksID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_GetPolicyByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
policyID string
expectError bool
}{
{
name: "retrieve existing policy",
policyID: "cs1tnh0hhcjnqoiuebf0",
expectError: false,
},
{
name: "retrieve non-existing policy checks",
policyID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty policy ID",
policyID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, policy)
} else {
require.NoError(t, err)
require.NotNil(t, policy)
require.Equal(t, tt.policyID, policy.ID)
}
})
}
}
func TestSqlStore_CreatePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policy := &Policy{
ID: "policy-id",
AccountID: accountID,
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err)
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
require.NoError(t, err)
require.Equal(t, savePolicy, policy)
}
func TestSqlStore_SavePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policyID := "cs1tnh0hhcjnqoiuebf0"
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
require.NoError(t, err)
policy.Enabled = false
policy.Description = "policy"
policy.Rules[0].Sources = []string{"group"}
policy.Rules[0].Ports = []string{"80", "443"}
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err)
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
require.NoError(t, err)
require.Equal(t, savePolicy, policy)
}
func TestSqlStore_DeletePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policyID := "cs1tnh0hhcjnqoiuebf0"
err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID)
require.NoError(t, err)
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
require.Error(t, err)
require.Nil(t, policy)
}
func TestSqlStore_GetDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectError bool
}{
{
name: "retrieve existing account dns settings",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectError: false,
},
{
name: "retrieve non-existing account dns settings",
accountID: "non-existing",
expectError: true,
},
{
name: "retrieve dns settings with empty account ID",
accountID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, dnsSettings)
} else {
require.NoError(t, err)
require.NotNil(t, dnsSettings)
}
})
}
}
func TestSqlStore_SaveDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
require.NoError(t, err)
saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Equal(t, saveDNSSettings, dnsSettings)
}
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve name server groups by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}
func TestSqlStore_GetNameServerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
nsGroupID string
expectError bool
}{
{
name: "retrieve existing nameserver group",
nsGroupID: "csqdelq7qv97ncu7d9t0",
expectError: false,
},
{
name: "retrieve non-existing nameserver group",
nsGroupID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty nameserver group ID",
nsGroupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, nsGroup)
} else {
require.NoError(t, err)
require.NotNil(t, nsGroup)
require.Equal(t, tt.nsGroupID, nsGroup.ID)
}
})
}
}
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nsGroup := &nbdns.NameServerGroup{
ID: "ns-group-id",
AccountID: accountID,
Name: "NS Group",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: 1,
Port: 53,
},
},
Groups: []string{"groupA"},
Primary: true,
Enabled: true,
SearchDomainsEnabled: false,
}
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup)
require.NoError(t, err)
saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID)
require.NoError(t, err)
require.Equal(t, saveNSGroup, nsGroup)
}
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nsGroupID := "csqdelq7qv97ncu7d9t0"
err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID)
require.NoError(t, err)
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID)
require.Error(t, err)
require.Nil(t, nsGroup)
}

View File

@@ -139,3 +139,18 @@ func NewGetAccountError(err error) error {
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
func NewPostureChecksNotFoundError(postureChecksID string) error {
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
}
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
func NewPolicyNotFoundError(policyID string) error {
return Errorf(NotFound, "policy: %s not found", policyID)
}
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
func NewNameServerGroupNotFoundError(nsGroupID string) error {
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
}

View File

@@ -59,6 +59,7 @@ type Store interface {
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
@@ -80,11 +81,17 @@ type Store interface {
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
@@ -110,6 +117,8 @@ type Store interface {
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error

View File

@@ -34,4 +34,7 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO installations VALUES(1,'');

View File

@@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
})
require.NoError(t, err)
policy := Policy{
ID: "policy",
policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -140,7 +140,7 @@ type Client struct {
instanceURL *RelayAddr
muInstanceURL sync.Mutex
onDisconnectListener func()
onDisconnectListener func(string)
onConnectedListener func()
listenerMutex sync.Mutex
}
@@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) {
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
@@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() {
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
go c.onDisconnectListener(c.connectionURL)
}
func (c *Client) notifyConnected() {

View File

@@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
relayClient.SetOnDisconnectListener(func(_ string) {
log.Infof("client disconnected")
close(disconnected)
})

View File

@@ -4,65 +4,120 @@ import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
var (
reconnectingTimeout = 5 * time.Second
reconnectingTimeout = 60 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct {
ctx context.Context
relayClient *Client
// OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
OnNewRelayClient chan *Client
serverPicker *ServerPicker
}
// NewGuard creates a new guard for the relay client.
func NewGuard(context context.Context, relayClient *Client) *Guard {
func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
OnNewRelayClient: make(chan *Client, 1),
serverPicker: sp,
}
return g
}
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
// to the same server that was used before, if the server URL is still valid. If the quick
// reconnect fails, it starts a ticker to periodically attempt server picking until it
// succeeds or the context is done.
//
// Parameters:
// - ctx: The context to control the lifecycle of the reconnection attempts.
// - relayClient: The relay client instance that was disconnected.
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() {
if g.quickReconnect() {
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
if relayClient == nil {
goto RETRY
}
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
return
}
ticker := time.NewTicker(reconnectingTimeout)
RETRY:
ticker := exponentTicker(ctx)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := g.relayClient.Connect()
if err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
if err := g.retry(ctx); err != nil {
log.Errorf("failed to pick new Relay server: %s", err)
continue
}
return
case <-g.ctx.Done():
case <-ctx.Done():
return
}
}
}
func (g *Guard) quickReconnect() bool {
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
func (g *Guard) retry(ctx context.Context) error {
log.Infof("try to pick up a new Relay server")
relayClient, err := g.serverPicker.PickServer(ctx)
if err != nil {
return err
}
// prevent to work with a deprecated Relay client instance
g.drainRelayClientChan()
g.OnNewRelayClient <- relayClient
return nil
}
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
defer cancel()
<-ctx.Done()
if g.ctx.Err() != nil {
if parentCtx.Err() != nil {
return false
}
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := g.relayClient.Connect(); err != nil {
if err := rc.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
func (g *Guard) drainRelayClientChan() {
select {
case <-g.OnNewRelayClient:
default:
}
}
func (g *Guard) isServerURLStillValid(rc *Client) bool {
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
if url == rc.connectionURL {
return true
}
}
return false
}
func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second,
Multiplier: 2,
MaxInterval: reconnectingTimeout,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}

View File

@@ -57,12 +57,15 @@ type ManagerService interface {
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
serverURLs []string
peerID string
tokenStore *relayAuth.TokenStore
ctx context.Context
peerID string
running bool
tokenStore *relayAuth.TokenStore
serverPicker *ServerPicker
relayClient *Client
relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.Mutex
reconnectGuard *Guard
relayClients map[string]*RelayTrack
@@ -76,48 +79,54 @@ type Manager struct {
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
return &Manager{
ctx: ctx,
serverURLs: serverURLs,
peerID: peerID,
tokenStore: &relayAuth.TokenStore{},
tokenStore := &relayAuth.TokenStore{}
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
m.serverPicker.ServerURLs.Store(serverURLs)
m.reconnectGuard = NewGuard(m.serverPicker)
return m
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
// Serve starts the manager, attempting to establish a connection with the relay server.
// If the connection fails, it will keep trying to reconnect in the background.
// Additionally, it starts a cleanup loop to remove unused relay connections.
// The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error {
if m.relayClient != nil {
if m.running {
return fmt.Errorf("manager already serving")
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
m.running = true
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
sp := ServerPicker{
TokenStore: m.tokenStore,
PeerID: m.peerID,
}
client, err := sp.PickServer(m.ctx, m.serverURLs)
client, err := m.serverPicker.PickServer(m.ctx)
if err != nil {
return err
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
} else {
m.storeClient(client)
}
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnConnectedListener(m.onServerConnected)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
go m.listenGuardEvent(m.ctx)
go m.startCleanupLoop()
return err
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil {
return nil, ErrRelayClientNotConnected
}
@@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil {
return false
}
@@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil {
return ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
@@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
if m.relayClient == nil {
return "", ErrRelayClientNotConnected
}
@@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
return m.serverURLs
return m.serverPicker.ServerURLs.Load().([]string)
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
return len(m.serverURLs) > 0
return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
}
func (m *Manager) UpdateServerURLs(serverURLs []string) {
log.Infof("update relay server URLs: %v", serverURLs)
m.serverPicker.ServerURLs.Store(serverURLs)
}
// UpdateToken updates the token in the token store.
@@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(serverAddress)
})
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
rt.relayClient = relayClient
rt.Unlock()
@@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() {
go m.onReconnectedListenerFn()
}
// onServerDisconnected start to reconnection for home server only
func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected()
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
}
m.relayClientMu.Unlock()
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) listenGuardEvent(ctx context.Context) {
for {
select {
case rc := <-m.reconnectGuard.OnNewRelayClient:
m.storeClient(rc)
case <-ctx.Done():
return
}
}
}
func (m *Manager) storeClient(client *Client) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
m.relayClient = client
m.relayClient.SetOnConnectedListener(m.onServerConnected)
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
@@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}()
}
}
func (m *Manager) cleanUpUnusedRelays() {

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -12,10 +13,13 @@ import (
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
var (
connectionTimeout = 30 * time.Second
)
type connResult struct {
RelayClient *Client
Url string
@@ -24,20 +28,22 @@ type connResult struct {
type ServerPicker struct {
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
totalServers := len(sp.ServerURLs.Load().([]string))
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
for _, url := range sp.ServerURLs.Load().([]string) {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
@@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)

View File

@@ -4,19 +4,23 @@ import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
connectionTimeout = 5 * time.Second
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
_, err := sp.PickServer(ctx)
if err == nil {
t.Error(err)
}

31
util/net/conn.go Normal file
View File

@@ -0,0 +1,31 @@
//go:build !ios
package net
import (
"net"
log "github.com/sirupsen/logrus"
)
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}

58
util/net/dial.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build !ios
package net
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -1,25 +0,0 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect socket: %d", fd)
}
})
return err
}
}

View File

@@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
@@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
return result.ErrorOrNil()
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = ControlProtectSocket
}

View File

@@ -7,6 +7,6 @@ import "syscall"
// init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
return setRawSocketMark(c)
}
}

View File

@@ -3,4 +3,5 @@
package net
func (d *Dialer) init() {
// implemented on Linux and Android only
}

29
util/net/env.go Normal file
View File

@@ -0,0 +1,29 @@
package net
import (
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
)
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
return os.Getenv(envDisableCustomRouting) == "true"
}
func SkipSocketMark() bool {
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
return true
}
return false
}

37
util/net/listen.go Normal file
View File

@@ -0,0 +1,37 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -1,26 +0,0 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect listener socket: %d", fd)
}
})
return err
}
}

View File

@@ -0,0 +1,6 @@
package net
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = ControlProtectSocket
}

View File

@@ -9,6 +9,6 @@ import (
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
return setRawSocketMark(c)
}
}

View File

@@ -3,4 +3,5 @@
package net
func (l *ListenerConfig) init() {
// implemented on Linux and Android only
}

View File

@@ -8,7 +8,6 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
return err
}
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -2,9 +2,6 @@ package net
import (
"net"
"os"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/google/uuid"
)
@@ -16,8 +13,6 @@ const (
PreroutingFwmarkRedirected = 0x1BD01
PreroutingFwmarkMasquerade = 0x1BD11
PreroutingFwmarkMasqueradeReturn = 0x1BD12
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -4,29 +4,42 @@ package net
import (
"fmt"
"os"
"syscall"
log "github.com/sirupsen/logrus"
)
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
if isSocketMarkDisabled() {
return nil
}
sysconn, err := conn.SyscallConn()
if err != nil {
return fmt.Errorf("get raw conn: %w", err)
}
return SetRawSocketMark(sysconn)
return setRawSocketMark(sysconn)
}
func SetRawSocketMark(conn syscall.RawConn) error {
// SetSocketOpt sets the SO_MARK option on the given file descriptor
func SetSocketOpt(fd int) error {
if isSocketMarkDisabled() {
return nil
}
return setSocketOptInt(fd)
}
func setRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = SetSocketOpt(int(fd))
if isSocketMarkDisabled() {
return
}
setErr = setSocketOptInt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil
}
func setSocketOptInt(fd int) error {
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}
func isSocketMarkDisabled() bool {
if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return true
}
if SkipSocketMark() {
return true
}
return false
}

View File

@@ -1,14 +1,42 @@
package net
import "sync"
import (
"fmt"
"sync"
"syscall"
)
var (
androidProtectSocketLock sync.Mutex
androidProtectSocket func(fd int32) bool
)
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
androidProtectSocketLock.Lock()
androidProtectSocket = f
androidProtectSocket = fn
androidProtectSocketLock.Unlock()
}
// ControlProtectSocket is a Control function that sets the fwmark on the socket
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
var aErr error
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
defer androidProtectSocketLock.Unlock()
if androidProtectSocket == nil {
aErr = fmt.Errorf("socket protection function not set")
return
}
if !androidProtectSocket(int32(fd)) {
aErr = fmt.Errorf("failed to protect socket via Android")
}
})
if err != nil {
return err
}
return aErr
}