mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
41 Commits
account-re
...
add-static
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75a02c3e38 | ||
|
|
052525093a | ||
|
|
b4b9aedf5a | ||
|
|
07f0f9fdbd | ||
|
|
c6641be94b | ||
|
|
89cf8a55e2 | ||
|
|
00c3b67182 | ||
|
|
9203690033 | ||
|
|
9683da54b0 | ||
|
|
0e48a772ff | ||
|
|
f118d81d32 | ||
|
|
ca12bc6953 | ||
|
|
9810386937 | ||
|
|
f1625b32bd | ||
|
|
0ecd5f2118 | ||
|
|
940d0c48c6 | ||
|
|
56cecf849e | ||
|
|
05c4aa7c2c | ||
|
|
2a5cb16494 | ||
|
|
9db1932664 | ||
|
|
1bbabf70b0 | ||
|
|
aa575d6f44 | ||
|
|
f66bbcc54c | ||
|
|
5dd6a08ea6 | ||
|
|
eb5d0569ae | ||
|
|
52ea2e84e9 | ||
|
|
78fab877c0 | ||
|
|
65a94f695f | ||
|
|
ec543f89fb | ||
|
|
a7d5c52203 | ||
|
|
582bb58714 | ||
|
|
121dfda915 | ||
|
|
a1c5287b7c | ||
|
|
12f442439a | ||
|
|
d9b691b8a5 | ||
|
|
4aee3c9e33 | ||
|
|
44e799c687 | ||
|
|
be78efbd42 | ||
|
|
6886691213 | ||
|
|
b48afd92fd | ||
|
|
39329e12a1 |
41
.github/workflows/golang-test-linux.yml
vendored
41
.github/workflows/golang-test-linux.yml
vendored
@@ -52,6 +52,47 @@ jobs:
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
benchmark:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
@@ -17,8 +17,12 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
@@ -30,7 +34,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
@@ -179,7 +179,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,9 +137,11 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.Config, setupKey string) error {
|
||||
needsLogin := false
|
||||
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||
staticInfo := <-staticInfoChan
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, "", "")
|
||||
err := internal.Login(ctx, config, "", "", staticInfo)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -162,7 +164,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
var lastError error
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||
err := internal.Login(ctx, config, setupKey, jwtToken, staticInfo)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
lastError = err
|
||||
return nil
|
||||
|
||||
@@ -152,6 +152,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@@ -171,7 +173,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, <-staticInfoChan)
|
||||
return connectClient.Run()
|
||||
}
|
||||
|
||||
|
||||
@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
@@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runIptablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("iptables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err, "iptables-save failed to run")
|
||||
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
// Check for any incompatibility warnings
|
||||
require.NotContains(t,
|
||||
stderr,
|
||||
"incompatible",
|
||||
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||
stderr,
|
||||
)
|
||||
|
||||
// Verify standard tables are present
|
||||
expectedTables := []string{
|
||||
"*filter",
|
||||
"*nat",
|
||||
"*mangle",
|
||||
}
|
||||
|
||||
for _, table := range expectedTables {
|
||||
require.Contains(t,
|
||||
stdout,
|
||||
table,
|
||||
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||
table,
|
||||
stdout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(path, config)
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,19 +40,21 @@ type ConnectClient struct {
|
||||
statusRecorder *peer.Status
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
staticInfo *system.StaticInfo
|
||||
}
|
||||
|
||||
func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
statusRecorder *peer.Status,
|
||||
|
||||
staticInfo *system.StaticInfo,
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
staticInfo: staticInfo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +159,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
c.statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
@@ -178,7 +181,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
}()
|
||||
|
||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
|
||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.staticInfo)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||
@@ -231,6 +234,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
if token != nil {
|
||||
if err := relayManager.UpdateToken(token); err != nil {
|
||||
@@ -241,9 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
||||
if err = relayManager.Serve(); err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
}
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
@@ -257,7 +259,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks, c.staticInfo)
|
||||
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -424,14 +426,14 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
||||
}
|
||||
|
||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo := system.GetInfo(ctx, staticInfo)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
@@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
@@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
|
||||
@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Domains: []string{"google.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -38,7 +39,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -171,7 +171,9 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
staticInfo *system.StaticInfo
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -180,8 +182,8 @@ type Peer struct {
|
||||
WgAllowedIps string
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine
|
||||
func NewEngine(
|
||||
// newEngine creates a new Connection Engine
|
||||
func newEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
@@ -203,6 +205,7 @@ func NewEngine(
|
||||
statusRecorder,
|
||||
nil,
|
||||
checks,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -218,6 +221,7 @@ func NewEngineWithProbes(
|
||||
statusRecorder *peer.Status,
|
||||
probes *ProbeHolder,
|
||||
checks []*mgmProto.Checks,
|
||||
staticInfo *system.StaticInfo,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
@@ -237,6 +241,7 @@ func NewEngineWithProbes(
|
||||
statusRecorder: statusRecorder,
|
||||
probes: probes,
|
||||
checks: checks,
|
||||
staticInfo: staticInfo,
|
||||
}
|
||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||
engine.stateManager = statemanager.New(path)
|
||||
@@ -297,7 +302,7 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
@@ -538,6 +543,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
|
||||
relayMsg := wCfg.GetRelay()
|
||||
if relayMsg != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
c := &auth.Token{
|
||||
Payload: relayMsg.GetTokenPayload(),
|
||||
Signature: relayMsg.GetTokenSignature(),
|
||||
@@ -546,9 +552,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
log.Errorf("failed to update relay token: %v", err)
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
_ = e.relayManager.Serve()
|
||||
} else {
|
||||
e.relayManager.UpdateServerURLs(nil)
|
||||
}
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@@ -574,10 +587,10 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.staticInfo)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
info = system.GetInfo(e.ctx, e.staticInfo)
|
||||
}
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -641,6 +654,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address().String()
|
||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||
@@ -673,10 +690,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.staticInfo)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
info = system.GetInfo(e.ctx, e.staticInfo)
|
||||
}
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
@@ -1180,7 +1197,7 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
info := system.GetInfo(e.ctx)
|
||||
info := system.GetInfo(e.ctx, e.staticInfo)
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -1481,6 +1498,17 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
sort.Slice(check.Files, func(i, j int) bool {
|
||||
return check.Files[i] < check.Files[j]
|
||||
})
|
||||
}
|
||||
for _, oCheck := range oChecks {
|
||||
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||
return oCheck.Files[i] < oCheck.Files[j]
|
||||
})
|
||||
}
|
||||
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
|
||||
@@ -84,7 +84,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(
|
||||
engine := newEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
@@ -229,7 +229,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(
|
||||
engine := newEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
@@ -434,7 +434,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
@@ -594,7 +594,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
@@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := newEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckFilesEqual(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputChecks1 []*mgmtProto.Checks
|
||||
inputChecks2 []*mgmtProto.Checks
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "Equal Files In Equal Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Equal Files In Reverse Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile2",
|
||||
"testfile1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Unequal Files Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile3",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "Compared With Empty Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -1025,7 +1118,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
info := system.GetInfo(ctx, nil)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1047,7 +1140,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e, err := newEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
// IsLoginRequired check that the server is support SSO or not
|
||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
|
||||
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string, staticInfo *system.StaticInfo) (bool, error) {
|
||||
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -38,7 +38,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
|
||||
if isLoginNeeded(err) {
|
||||
return true, nil
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
|
||||
}
|
||||
|
||||
// Login or register the client
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string) error {
|
||||
func Login(ctx context.Context, config *Config, setupKey string, jwtToken string, staticInfo *system.StaticInfo) error {
|
||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -67,10 +67,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
|
||||
return err
|
||||
}
|
||||
|
||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
|
||||
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, staticInfo)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
|
||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, staticInfo)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -99,28 +99,28 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
|
||||
return mgmClient, err
|
||||
}
|
||||
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
|
||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, staticInfo *system.StaticInfo) (*wgtypes.Key, error) {
|
||||
serverKey, err := mgmClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo := system.GetInfo(ctx, staticInfo)
|
||||
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
|
||||
return serverKey, err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, staticInfo *system.StaticInfo) (*mgmProto.LoginResponse, error) {
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
info := system.GetInfo(ctx, staticInfo)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
|
||||
|
||||
@@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
// extend the list of stun, turn servers with relay address
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
var relayState relay.ProbeResult
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
})
|
||||
}
|
||||
return relayStates
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
relayState.Err = err
|
||||
return relayStates
|
||||
}
|
||||
|
||||
relayState.URI = instanceAddr
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
}
|
||||
return append(relayStates, relayState)
|
||||
}
|
||||
|
||||
|
||||
@@ -46,8 +46,6 @@ type WorkerICE struct {
|
||||
hasRelayOnLocally bool
|
||||
conn WorkerICECallbacks
|
||||
|
||||
selectedPriority ConnPriority
|
||||
|
||||
agent *ice.Agent
|
||||
muxAgent sync.Mutex
|
||||
|
||||
@@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
var preferredCandidateTypes []ice.CandidateType
|
||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||
w.selectedPriority = connPriorityICEP2P
|
||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||
} else {
|
||||
w.selectedPriority = connPriorityICETurn
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
@@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
}
|
||||
w.log.Debugf("on ICE conn read to use ready")
|
||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
||||
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
@@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
||||
if isRelayed(pair) {
|
||||
return connPriorityICETurn
|
||||
} else {
|
||||
return connPriorityICEP2P
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "relayed routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "p2p routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current route with bad score should be changed to route with better score",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// fill the test data with random routes
|
||||
for _, tc := range testCases {
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
|
||||
@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||
type Counter[Key comparable, I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for keys
|
||||
refCountMap map[Key]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
// idMap keeps track of the keys associated with an ID for removal
|
||||
idMap map[string][]Key
|
||||
idMu sync.Mutex
|
||||
add AddFunc[Key, I, O]
|
||||
remove RemoveFunc[Key, O]
|
||||
}
|
||||
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
return ref, ok
|
||||
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
// Increment increments the reference count for the given key.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return rm.increment(key, in)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||
ref := rm.refCountMap[key]
|
||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||
|
||||
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(key, in)
|
||||
ref, err := rm.increment(key, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
|
||||
// Decrement decrements the reference count for the given key.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
return rm.decrement(key)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||
ref, ok := rm.refCountMap[key]
|
||||
if !ok {
|
||||
logCallerF("No reference found for key %v", key)
|
||||
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, key := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(key); err != nil {
|
||||
if _, err := rm.decrement(key); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each key.
|
||||
func (rm *Counter[Key, I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for key := range rm.refCountMap {
|
||||
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
|
||||
|
||||
// Clear removes all references without calling RemoveFunc.
|
||||
func (rm *Counter[Key, I, O]) Clear() {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
@@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
|
||||
@@ -2,31 +2,28 @@ package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
type ShutdownState ExclusionCounter
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "route_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.Counter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData(s.Counter)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
return (*ExclusionCounter)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return (*ExclusionCounter)(s).UnmarshalJSON(data)
|
||||
}
|
||||
|
||||
@@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.updateState(stateManager)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||
state := getState(stateManager)
|
||||
|
||||
state.Counter = r.refCounter
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
@@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
|
||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
var shutdownState *ShutdownState
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = &ShutdownState{}
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type ruleParams struct {
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
}
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// State interface defines the methods that all state types must implement
|
||||
@@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
m.cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
bs, err := marshalWithPanicRecovery(m.states)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
start := time.Now()
|
||||
go func() {
|
||||
data, err := json.MarshalIndent(m.states, "", " ")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("marshal states: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
||||
done <- fmt.Errorf("write state file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
||||
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||
|
||||
clear(m.dirty)
|
||||
|
||||
@@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||
var bs []byte
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during marshal: %v", r)
|
||||
}
|
||||
}()
|
||||
bs, err = json.Marshal(v)
|
||||
}()
|
||||
|
||||
return bs, err
|
||||
}
|
||||
|
||||
@@ -4,32 +4,20 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
||||
// It returns an empty string if the path cannot be determined.
|
||||
func GetDefaultStatePath() string {
|
||||
var path string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
case "darwin", "linux":
|
||||
path = "/var/lib/netbird/state.json"
|
||||
return "/var/lib/netbird/state.json"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
path = "/var/db/netbird/state.json"
|
||||
// ios/android don't need state
|
||||
default:
|
||||
return ""
|
||||
return "/var/db/netbird/state.json"
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, &system.StaticInfo{})
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func (c *Client) IsLoginRequired() bool {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
|
||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey, &system.StaticInfo{})
|
||||
return needsLogin
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ func (c *Client) LoginForMobile() string {
|
||||
return
|
||||
}
|
||||
jwtToken := tokenInfo.GetTokenToUse()
|
||||
_ = internal.Login(ctx, cfg, "", jwtToken)
|
||||
_ = internal.Login(ctx, cfg, "", jwtToken, &system.StaticInfo{})
|
||||
c.loginComplete = true
|
||||
}()
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "", &system.StaticInfo{})
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
@@ -123,7 +123,7 @@ func (a *Auth) Login() error {
|
||||
|
||||
// check if we need to generate JWT token
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey, &system.StaticInfo{})
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
@@ -136,7 +136,7 @@ func (a *Auth) Login() error {
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||
err := internal.Login(a.ctx, a.config, "", jwtToken, &system.StaticInfo{})
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -68,6 +68,7 @@ type Server struct {
|
||||
relayProbe *internal.Probe
|
||||
wgProbe *internal.Probe
|
||||
lastProbe time.Time
|
||||
staticInfo *system.StaticInfo
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -79,6 +80,8 @@ type oauthAuthFlow struct {
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, configPath, logFile string) *Server {
|
||||
staticInfoChan := system.GetStaticInfoInBackground(ctx)
|
||||
staticInfo := <-staticInfoChan
|
||||
return &Server{
|
||||
rootCtx: ctx,
|
||||
latestConfigInput: internal.ConfigInput{
|
||||
@@ -89,6 +92,7 @@ func New(ctx context.Context, configPath, logFile string) *Server {
|
||||
signalProbe: internal.NewProbe(),
|
||||
relayProbe: internal.NewProbe(),
|
||||
wgProbe: internal.NewProbe(),
|
||||
staticInfo: staticInfo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +199,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
|
||||
|
||||
runOperation := func() error {
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.staticInfo)
|
||||
|
||||
probes := internal.ProbeHolder{
|
||||
MgmProbe: s.mgmProbe,
|
||||
@@ -272,7 +276,7 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
|
||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||
var status internal.StatusType
|
||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
||||
err := internal.Login(ctx, s.config, setupKey, jwtToken, s.staticInfo)
|
||||
if err != nil {
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
log.Warnf("failed login: %v", err)
|
||||
|
||||
@@ -61,6 +61,14 @@ type Info struct {
|
||||
Files []File // for posture checks
|
||||
}
|
||||
|
||||
// StaticInfo is an object that contains machine information that does not change
|
||||
type StaticInfo struct {
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
@@ -142,7 +150,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, staticInfo *StaticInfo) (*Info, error) {
|
||||
processCheckPaths := make([]string, 0)
|
||||
for _, check := range checks {
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
@@ -153,8 +161,17 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := GetInfo(ctx)
|
||||
info := GetInfo(ctx, staticInfo)
|
||||
info.Files = files
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetStaticInfoInBackground retrieves and parses the system information in the background
|
||||
func GetStaticInfoInBackground(ctx context.Context) <-chan *StaticInfo {
|
||||
ch := make(chan *StaticInfo)
|
||||
go func() {
|
||||
ch <- getStaticInfo(ctx)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
|
||||
kernel := "android"
|
||||
osInfo := uname()
|
||||
if len(osInfo) == 2 {
|
||||
@@ -44,6 +44,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
)
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||
utsname := unix.Utsname{}
|
||||
err := unix.Uname(&utsname)
|
||||
if err != nil {
|
||||
@@ -41,26 +41,22 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
gio := &Info{
|
||||
Kernel: sysName,
|
||||
OSVersion: strings.TrimSpace(string(swVersion)),
|
||||
Platform: machine,
|
||||
OS: sysName,
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: release,
|
||||
NetworkAddresses: addrs,
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: sysName,
|
||||
OSVersion: strings.TrimSpace(string(swVersion)),
|
||||
Platform: machine,
|
||||
OS: sysName,
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: release,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
if staticInfo != nil {
|
||||
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||
gio.SystemProductName = staticInfo.SystemProductName
|
||||
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||
gio.Environment = staticInfo.Environment
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@@ -71,6 +67,21 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
return &StaticInfo{
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
out, _ := exec.Command("/usr/sbin/ioreg", "-l").Output() // err ignored for brevity
|
||||
for _, l := range strings.Split(string(out), "\n") {
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||
out := _getInfo()
|
||||
for strings.Contains(out, "broken pipe") {
|
||||
out = _getInfo()
|
||||
@@ -29,16 +29,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
return &Info{
|
||||
info := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
@@ -49,7 +44,25 @@ func GetInfo(ctx context.Context) *Info {
|
||||
WiretrusteeVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
}
|
||||
if staticInfo != nil {
|
||||
info.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||
info.SystemProductName = staticInfo.SystemProductName
|
||||
info.SystemManufacturer = staticInfo.SystemManufacturer
|
||||
info.Environment = staticInfo.Environment
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
return &StaticInfo{
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, _ *StaticInfo) *Info {
|
||||
|
||||
// Convert fixed-size byte arrays to Go strings
|
||||
sysName := extractOsName(ctx, "sysName")
|
||||
@@ -25,6 +25,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
|
||||
@@ -42,7 +42,7 @@ func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||
info := _getInfo()
|
||||
for strings.Contains(info, "broken pipe") {
|
||||
info = _getInfo()
|
||||
@@ -65,14 +65,6 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: osInfo[0],
|
||||
Platform: osInfo[2],
|
||||
@@ -85,13 +77,32 @@ func GetInfo(ctx context.Context) *Info {
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
}
|
||||
|
||||
if staticInfo != nil {
|
||||
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||
gio.SystemProductName = staticInfo.SystemProductName
|
||||
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||
gio.Environment = staticInfo.Environment
|
||||
}
|
||||
|
||||
return gio
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
return &StaticInfo{
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
|
||||
return gio
|
||||
}
|
||||
|
||||
func _getInfo() string {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func Test_LocalWTVersion(t *testing.T) {
|
||||
got := GetInfo(context.TODO())
|
||||
got := GetInfo(context.TODO(), nil)
|
||||
want := "development"
|
||||
assert.Equal(t, want, got.WiretrusteeVersion)
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func Test_UIVersion(t *testing.T) {
|
||||
"user-agent": {want},
|
||||
})
|
||||
|
||||
got := GetInfo(ctx)
|
||||
got := GetInfo(ctx, nil)
|
||||
assert.Equal(t, want, got.UIVersion)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func Test_CustomHostname(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), DeviceNameCtxKey, "custom-host")
|
||||
want := "custom-host"
|
||||
|
||||
got := GetInfo(ctx)
|
||||
got := GetInfo(ctx, nil)
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ type Win32_BIOS struct {
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
func GetInfo(ctx context.Context, staticInfo *StaticInfo) *Info {
|
||||
osName, osVersion := getOSNameAndVersion()
|
||||
buildVersion := getBuildVersion()
|
||||
|
||||
@@ -42,39 +42,22 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, err := sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
prodName, err := sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err := sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
Kernel: "windows",
|
||||
OSVersion: osVersion,
|
||||
Platform: "unknown",
|
||||
OS: osName,
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: buildVersion,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
Kernel: "windows",
|
||||
OSVersion: osVersion,
|
||||
Platform: "unknown",
|
||||
OS: osName,
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: buildVersion,
|
||||
NetworkAddresses: addrs,
|
||||
}
|
||||
|
||||
if staticInfo != nil {
|
||||
gio.SystemSerialNumber = staticInfo.SystemSerialNumber
|
||||
gio.SystemProductName = staticInfo.SystemProductName
|
||||
gio.SystemManufacturer = staticInfo.SystemManufacturer
|
||||
gio.Environment = staticInfo.Environment
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@@ -85,6 +68,41 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func getStaticInfo(ctx context.Context) *StaticInfo {
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
return &StaticInfo{
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var err error
|
||||
serialNumber, err = sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
productName, err = sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err = sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
return serialNumber, productName, manufacturer
|
||||
}
|
||||
|
||||
func getOSNameAndVersion() (string, string) {
|
||||
var dst []Win32_OperatingSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
|
||||
2
go.mod
2
go.mod
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
4
go.sum
4
go.sum
@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
||||
@@ -174,7 +174,7 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sysInfo := system.GetInfo(context.TODO())
|
||||
sysInfo := system.GetInfo(context.TODO(), nil)
|
||||
_, err = client.Login(*key, sysInfo, nil)
|
||||
if err == nil {
|
||||
t.Error("expecting err on unregistered login, got nil")
|
||||
@@ -202,7 +202,7 @@ func TestClient_LoginRegistered(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
info := system.GetInfo(context.TODO())
|
||||
info := system.GetInfo(context.TODO(), nil)
|
||||
resp, err := client.Register(*key, ValidKey, "", info, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -232,7 +232,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
info := system.GetInfo(context.TODO(), nil)
|
||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -248,7 +248,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info = system.GetInfo(context.TODO())
|
||||
info = system.GetInfo(context.TODO(), nil)
|
||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -346,7 +346,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
info := system.GetInfo(context.TODO(), nil)
|
||||
_, err = testClient.Register(*key, ValidKey, "", info, nil)
|
||||
if err != nil {
|
||||
t.Errorf("error while trying to register client: %v", err)
|
||||
|
||||
@@ -110,11 +110,10 @@ type AccountManager interface {
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@@ -140,7 +139,7 @@ type AccountManager interface {
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManager() idp.Manager
|
||||
@@ -966,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
|
||||
}
|
||||
|
||||
// UserGroupsAddToPeers adds groups to all peers of user
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
userPeers := make(map[string]struct{})
|
||||
for pid, peer := range a.Peers {
|
||||
if peer.UserID == userID {
|
||||
@@ -980,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
groupPeers := make(map[string]struct{})
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
@@ -993,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
groupUpdates[gid] = difference(group.Peers, oldPeers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
@@ -1014,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
}
|
||||
group.Peers = update
|
||||
groupUpdates[gid] = difference(oldPeers, group.Peers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
@@ -1176,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@@ -1186,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||
if newSettings.GroupsPropagationEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||
// Todo: retroactively add user groups to all peers
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
} else {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1435,7 +1473,7 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2083,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -2101,7 +2139,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2114,7 +2152,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2127,14 +2165,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||
if err != nil {
|
||||
return status.NewGetAccountError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2319,7 +2362,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -2335,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2398,12 +2444,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
return
|
||||
}
|
||||
am.updateAccountPeers(ctx, updatedAccount)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
|
||||
@@ -6,13 +6,17 @@ import (
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
}
|
||||
|
||||
b.Run("public without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
@@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("private without account ID", func(b *testing.B) {
|
||||
//b.ResetTimer()
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
@@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
|
||||
b.Run("private with account ID", func(b *testing.B) {
|
||||
claims.AccountId = id
|
||||
//b.ResetTimer()
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
@@ -1238,8 +1242,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1250,8 +1253,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1320,19 +1322,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1345,7 +1334,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1366,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1377,9 +1378,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1413,13 +1413,20 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1430,14 +1437,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1460,7 +1461,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -2714,7 +2715,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2734,7 +2735,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2985,3 +2986,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 1, 3, 4, 10},
|
||||
{"Medium", 500, 100, 7, 13, 10, 60},
|
||||
{"Large", 5000, 200, 65, 80, 60, 170},
|
||||
{"Small single", 50, 10, 1, 3, 4, 60},
|
||||
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||
{"Large 5", 5000, 15, 65, 80, 60, 170},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 102, 110, 102, 120},
|
||||
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||
{"Large", 5000, 200, 160, 200, 160, 270},
|
||||
{"Small single", 50, 10, 102, 110, 102, 120},
|
||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||
{"Large 5", 5000, 15, 160, 200, 160, 270},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
UserID: "regular_user",
|
||||
SetupKey: "",
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 107, 120, 107, 140},
|
||||
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||
{"Large", 5000, 200, 180, 220, 180, 320},
|
||||
{"Small single", 50, 10, 107, 120, 105, 140},
|
||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||
{"Large 5", 5000, 15, 180, 220, 180, 320},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
UserID: "regular_user",
|
||||
SetupKey: "",
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +148,9 @@ const (
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
|
||||
SetupKeyDeleted Activity = 68
|
||||
|
||||
UserGroupPropagationEnabled Activity = 69
|
||||
UserGroupPropagationDisabled Activity = 70
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||
|
||||
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
|
||||
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
||||
}
|
||||
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range addedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
for _, groupID := range removedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
|
||||
@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(file, s)
|
||||
err := util.WriteJson(context.Background(), file, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -27,18 +28,17 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -49,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -58,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -77,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Avoid duplicate groups only for the API issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
if oldGroup != nil {
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
@@ -159,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
for _, peerID := range removedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||
continue
|
||||
}
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -210,42 +210,10 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
@@ -254,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*nbgroup.Group
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
for _, group := range deletedGroups {
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
for _, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
add = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if add {
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -351,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// validateNewGroup validates the new group for existence and required fields.
|
||||
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent duplicate groups for API-issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||
}
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
@@ -446,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
@@ -454,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
@@ -468,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
@@ -477,8 +537,36 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
@@ -487,21 +575,18 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool {
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == "All"
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
func (g *Group) AddPeer(peerID string) bool {
|
||||
if peerID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.Peers = append(g.Peers, peerID)
|
||||
return true
|
||||
}
|
||||
|
||||
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||
func (g *Group) RemovePeer(peerID string) bool {
|
||||
for i, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
90
management/server/group/group_test.go
Normal file
90
management/server/group/group_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPeer(t *testing.T) {
|
||||
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add duplicate peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("add empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||
peerID := "peer2"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Nil(t, group.Peers)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1"}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("remove empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
expectedReasons: []string{"group: non-existent-group not found"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
@@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving a group linked to policy should update account peers and send peer update
|
||||
|
||||
@@ -439,17 +439,13 @@ components:
|
||||
example: 5
|
||||
required:
|
||||
- accessible_peers_count
|
||||
SetupKey:
|
||||
SetupKeyBase:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Setup Key ID
|
||||
type: string
|
||||
example: 2531583362
|
||||
key:
|
||||
description: Setup Key value
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
name:
|
||||
description: Setup key name identifier
|
||||
type: string
|
||||
@@ -518,22 +514,31 @@ components:
|
||||
- updated_at
|
||||
- usage_limit
|
||||
- ephemeral
|
||||
SetupKeyClear:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as plain text
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
required:
|
||||
- key
|
||||
SetupKey:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as secret
|
||||
type: string
|
||||
example: A6160****
|
||||
required:
|
||||
- key
|
||||
SetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds, 0 will mean the key never expires
|
||||
type: integer
|
||||
minimum: 0
|
||||
example: 86400
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
type: boolean
|
||||
@@ -544,21 +549,9 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -1943,7 +1936,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKey'
|
||||
$ref: '#/components/schemas/SetupKeyClear'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
|
||||
@@ -1062,7 +1062,94 @@ type SetupKey struct {
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key value
|
||||
// Key Setup Key as secret
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyBase defines model for SetupKeyBase.
|
||||
type SetupKeyBase struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyClear defines model for SetupKeyClear.
|
||||
type SetupKeyClear struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key as plain text
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
@@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// User defines model for User.
|
||||
|
||||
@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groupsMap := map[string]*nbgroup.Group{}
|
||||
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsInfo := []api.GroupMinimum{}
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
|
||||
@@ -6,10 +6,8 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
isUpdate := policyID != ""
|
||||
|
||||
if policyID == "" {
|
||||
policyID = xid.New().String()
|
||||
}
|
||||
|
||||
policy := server.Policy{
|
||||
policy := &server.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
}
|
||||
for _, rule := range req.Rules {
|
||||
var ruleID string
|
||||
if rule.Id != nil {
|
||||
ruleID = *rule.Id
|
||||
}
|
||||
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
ID: ruleID,
|
||||
PolicyID: policyID,
|
||||
Name: rule.Name,
|
||||
Destinations: rule.Destinations,
|
||||
Sources: rule.Sources,
|
||||
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||
}
|
||||
|
||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(allGroups, &policy)
|
||||
resp := toPolicyResponse(allGroups, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
|
||||
@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
return policy, nil
|
||||
},
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
|
||||
@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
return nil
|
||||
return postureChecks, nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||
_, ok := testPostureChecks[postureChecksID]
|
||||
|
||||
@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey := &server.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
|
||||
@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, group := range groups {
|
||||
var found bool
|
||||
for _, accountGroup := range accountsGroups {
|
||||
if accountGroup.ID == group {
|
||||
found = true
|
||||
break
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
@@ -45,12 +45,11 @@ type MockAccountManager struct {
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
@@ -97,7 +96,7 @@ type MockAccountManager struct {
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
return am.ListGroupsFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupAddPeerFunc != nil {
|
||||
@@ -395,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
||||
}
|
||||
|
||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
|
||||
if am.SavePolicyFunc != nil {
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
}
|
||||
|
||||
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
||||
@@ -739,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
if am.SavePostureChecksFunc != nil {
|
||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
}
|
||||
|
||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||
|
||||
@@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
newNSGroup := &nbdns.NameServerGroup{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
NameServers: nameServerList,
|
||||
@@ -54,27 +62,34 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(false, newNSGroup, account)
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.NameServerGroups == nil {
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
||||
}
|
||||
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
|
||||
@@ -87,59 +102,96 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nsGroupToSave.AccountID = accountID
|
||||
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroup := account.NameServerGroups[nsGroupID]
|
||||
if nsGroup == nil {
|
||||
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||
nsGroupID := ""
|
||||
if existingGroup {
|
||||
nsGroupID = nameserverGroup.ID
|
||||
_, found := account.NameServerGroups[nsGroupID]
|
||||
if !found {
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSList(nameserverGroup.NameServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
||||
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(nameserverGroup.Groups, groups)
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
@@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
|
||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
}
|
||||
|
||||
for _, nsGroup := range nsGroupMap {
|
||||
for _, nsGroup := range groups {
|
||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
||||
}
|
||||
|
||||
func validateNSList(list []nbdns.NameServer) error {
|
||||
nsListLenght := len(list)
|
||||
if nsListLenght == 0 || nsListLenght > 3 {
|
||||
nsListLength := len(list)
|
||||
if nsListLength == 0 || nsListLength > 3 {
|
||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
||||
}
|
||||
return nil
|
||||
@@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
||||
if id == "" {
|
||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||
}
|
||||
found := false
|
||||
for groupID := range groups {
|
||||
if id == groupID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, found := groups[id]; !found {
|
||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||
}
|
||||
}
|
||||
@@ -277,11 +340,3 @@ func validateDomain(domain string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||
@@ -269,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if peerLabelUpdated || requiresPeerUpdates {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
@@ -333,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
if err != nil {
|
||||
@@ -346,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -553,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -596,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||
}
|
||||
|
||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if newGroupsAffectsPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
@@ -607,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||
postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
@@ -657,13 +671,16 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
account.Peers[peer.ID] = peer
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||
}
|
||||
|
||||
if sync.UpdateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -682,14 +699,18 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
|
||||
if isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
@@ -789,6 +810,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
@@ -813,7 +835,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
||||
@@ -862,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
@@ -976,7 +1002,13 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
@@ -1010,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
@@ -1030,12 +1067,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
|
||||
}
|
||||
|
||||
@@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
var (
|
||||
group1 nbgroup.Group
|
||||
group2 nbgroup.Group
|
||||
policy Policy
|
||||
)
|
||||
|
||||
group1.ID = xid.New().String()
|
||||
group2.ID = xid.New().String()
|
||||
group1.Name = "src"
|
||||
group2.Name = "dst"
|
||||
policy.ID = xid.New().String()
|
||||
group1.Peers = append(group1.Peers, peer1.ID)
|
||||
group2.Peers = append(group2.Peers, peer2.ID)
|
||||
|
||||
@@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policy.Name = "test"
|
||||
policy.Enabled = true
|
||||
policy.Rules = []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{group1.ID},
|
||||
Destinations: []string{group2.ID},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
policy := &Policy{
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{group1.ID},
|
||||
Destinations: []string{group2.ID},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
policy.Enabled = false
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@@ -833,19 +833,23 @@ func BenchmarkGetPeers(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
{"Small", 50, 5, 90, 120, 90, 120},
|
||||
{"Medium", 500, 100, 110, 140, 120, 200},
|
||||
{"Large", 5000, 200, 800, 1300, 2500, 3600},
|
||||
{"Small single", 50, 10, 90, 120, 90, 120},
|
||||
{"Medium single", 500, 10, 110, 170, 120, 200},
|
||||
{"Large 5", 5000, 15, 1300, 1800, 5000, 6000},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@@ -877,12 +881,27 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
manager.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
||||
b.ReportMetric(0, "ns/op")
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1445,8 +1464,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1457,7 +1475,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -3,13 +3,13 @@ package server
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -125,6 +125,7 @@ type PolicyRule struct {
|
||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
rule := &PolicyRule{
|
||||
ID: pm.ID,
|
||||
PolicyID: pm.PolicyID,
|
||||
Name: pm.Name,
|
||||
Description: pm.Description,
|
||||
Enabled: pm.Enabled,
|
||||
@@ -171,6 +172,7 @@ type Policy struct {
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
@@ -343,44 +345,72 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
return saveFunc(ctx, LockingStrengthUpdate, policy)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action := activity.PolicyAdded
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// DeletePolicy from the store
|
||||
@@ -388,112 +418,134 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := am.deletePolicy(account, policyID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var policy *Policy
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPolicies from the store
|
||||
// ListPolicies from the store.
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
policyIdx := -1
|
||||
for i, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
policyIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
|
||||
}
|
||||
|
||||
policy := account.Policies[policyIdx]
|
||||
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
policyToSave.Rules[index] = rule
|
||||
}
|
||||
|
||||
if policyToSave.SourcePostureChecks != nil {
|
||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldPolicy := account.Policies[policyIdx]
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
|
||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||
|
||||
return updateAccountPeers, nil
|
||||
}
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
|
||||
if policy.ID != "" {
|
||||
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
policy.ID = xid.New().String()
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
ruleCopy := rule.Copy()
|
||||
if ruleCopy.ID == "" {
|
||||
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
|
||||
ruleCopy.PolicyID = policy.ID
|
||||
}
|
||||
|
||||
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
|
||||
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
|
||||
policy.Rules[i] = ruleCopy
|
||||
}
|
||||
|
||||
if policy.SourcePostureChecks != nil {
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
@@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||
// that are valid within the provided account.
|
||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||
result := make([]string, 0, len(postureChecksIds))
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||
validIDs := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if id == postureCheck.ID {
|
||||
result = append(result, id)
|
||||
continue
|
||||
}
|
||||
if _, exists := postureChecks[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||
result := make([]string, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if _, exists := account.Groups[groupID]; exists {
|
||||
result = append(result, groupID)
|
||||
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
||||
validIDs := make([]string, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if _, exists := groups[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
@@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var policyWithGroupRulesNoPeers *Policy
|
||||
var policyWithDestinationPeersOnly *Policy
|
||||
var policyWithSourceAndDestinationPeers *Policy
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with source group containing peers, but destination group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
@@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: false,
|
||||
Enabled: true,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
@@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
@@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Enabled = false
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||
// or send peer update
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Description: "updated description",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Description = "updated description"
|
||||
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
policyWithSourceAndDestinationPeers.Enabled = true
|
||||
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Deleting policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
|
||||
}
|
||||
|
||||
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
||||
if postureChecksID == "" {
|
||||
postureChecksID = xid.New().String()
|
||||
}
|
||||
|
||||
postureChecks := Checks{
|
||||
ID: postureChecksID,
|
||||
Name: name,
|
||||
|
||||
@@ -2,237 +2,285 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
|
||||
// we do not allow create new posture checks with non uniq name
|
||||
if !exists && !uniqName {
|
||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
action := activity.PostureCheckCreated
|
||||
if exists {
|
||||
action = activity.PostureCheckUpdated
|
||||
account.Network.IncSerial()
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
var updateAccountPeers bool
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
||||
var postureChecks *posture.Checks
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPostureChecks returns a list of posture checks.
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||
uniqName = true
|
||||
for i, p := range account.PostureChecks {
|
||||
if !exists && p.ID == postureChecks.ID {
|
||||
account.PostureChecks[i] = postureChecks
|
||||
exists = true
|
||||
}
|
||||
if p.Name == postureChecks.Name {
|
||||
uniqName = false
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
||||
postureChecksIdx := -1
|
||||
for i, postureChecks := range account.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
postureChecksIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||
}
|
||||
|
||||
// Check if posture check is linked to any policy
|
||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||
}
|
||||
|
||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
||||
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
|
||||
peerPostureChecks := make(map[string]posture.Checks)
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled {
|
||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
|
||||
for _, check := range peerPostureChecks {
|
||||
checkCopy := check
|
||||
postureChecksList = append(postureChecksList, &checkCopy)
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return postureChecksList
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
// If the posture check already has an ID, verify its existence in the store.
|
||||
if postureChecks.ID != "" {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For new posture checks, ensure no duplicates by name.
|
||||
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
|
||||
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
|
||||
}
|
||||
}
|
||||
|
||||
postureChecks.ID = xid.New().String()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isInGroup {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck := account.getPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
||||
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, ok := account.Groups[sourceGroup]
|
||||
if ok && slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
group := account.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if postureCheck.ID == sourcePostureCheckID {
|
||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||
for _, policy := range account.Policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||
if !exists {
|
||||
return false
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||
if !isLinked {
|
||||
return false
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||
}
|
||||
}
|
||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
const (
|
||||
adminUserID = "adminUserID"
|
||||
regularUserID = "regularUserID"
|
||||
postureCheckID = "existing-id"
|
||||
postureCheckName = "Existing check"
|
||||
)
|
||||
|
||||
@@ -33,7 +32,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
|
||||
t.Run("Generic posture check flow", func(t *testing.T) {
|
||||
// regular users can not create checks
|
||||
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||
assert.Error(t, err)
|
||||
|
||||
// regular users cannot list check
|
||||
@@ -41,8 +40,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
|
||||
// should be possible to create posture check with uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
@@ -58,8 +56,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
assert.Len(t, checks, 1)
|
||||
|
||||
// should not be possible to create posture check with non uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: "new-id",
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||
@@ -74,23 +71,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
|
||||
// admins can update posture checks
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.27.0",
|
||||
},
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.27.0",
|
||||
},
|
||||
})
|
||||
}
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// users should not be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// admin should be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
@@ -150,9 +144,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
ID: "postureCheck",
|
||||
Name: "postureCheck",
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "postureCheckA",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
|
||||
require.NoError(t, err)
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
Name: "postureCheckB",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
@@ -169,7 +176,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -187,12 +194,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -202,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
policy := Policy{
|
||||
ID: "policyA",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -215,7 +220,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
}
|
||||
|
||||
// Linking posture check to policy should trigger update account peers and send peer update
|
||||
@@ -226,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -238,7 +243,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Updating linked posture checks should update account peers and send peer update
|
||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
@@ -255,7 +260,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -274,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
policy.SourcePostureChecks = []string{}
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -293,7 +297,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -303,17 +307,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -321,9 +323,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -332,12 +333,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -354,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -367,10 +367,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -379,12 +377,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -397,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -409,9 +406,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
SourcePostureChecks: []string{postureCheckB.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -420,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
postureCheckB.Checks = posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
@@ -429,7 +425,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -440,80 +436,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policyA",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*group.Group{
|
||||
"groupA": {
|
||||
ID: "groupA",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
"groupB": {
|
||||
ID: "groupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "checkA",
|
||||
},
|
||||
{
|
||||
ID: "checkB",
|
||||
},
|
||||
},
|
||||
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestPostureChecksAccount(manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
groupA := &group.Group{
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Peers: []string{"peer1"},
|
||||
}
|
||||
|
||||
groupB := &group.Group{
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "checkA",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
|
||||
require.NoError(t, err, "failed to save postureCheckA")
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
Name: "checkB",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
|
||||
require.NoError(t, err, "failed to save postureCheckB")
|
||||
|
||||
policy := &Policy{
|
||||
AccountID: account.Id,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheckA.ID},
|
||||
}
|
||||
|
||||
policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to save policy")
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupB"}
|
||||
policy.Rules[0].Destinations = []string{"groupA"}
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupA"}
|
||||
policy.Rules[0].Destinations = []string{"groupB"}
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
account.Groups["groupA"].Peers = []string{}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
policy.Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
@@ -323,8 +323,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
@@ -355,8 +355,8 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if am.isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
}
|
||||
|
||||
@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.ListGroups(context.Background(), account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *nbgroup.Group
|
||||
for _, group := range groups {
|
||||
@@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
|
||||
defaultRule := rules[0]
|
||||
newPolicy := defaultRule.Copy()
|
||||
newPolicy.ID = xid.New().String()
|
||||
newPolicy.Name = "peer1 only"
|
||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||
_, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||
|
||||
@@ -12,9 +12,10 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
@@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
if oldKey.Revoked && !keyToSave.Revoked {
|
||||
return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status (from false to true) can be updated
|
||||
newKey = oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
@@ -375,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -449,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
|
||||
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range removedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -469,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
||||
for _, g := range addedGroups {
|
||||
group, ok := groups[g]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, keyName, ev.Meta["name"])
|
||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
type testCase struct {
|
||||
name string
|
||||
keyId string
|
||||
expectedFailure bool
|
||||
}
|
||||
|
||||
testCase1 := testCase{
|
||||
name: "Should get existing Setup Key",
|
||||
keyId: plainKey.Id,
|
||||
expectedFailure: false,
|
||||
}
|
||||
testCase2 := testCase{
|
||||
name: "Should fail to get non-existent Setup Key",
|
||||
keyId: "some key",
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
if err == nil {
|
||||
t.Fatal("expected to fail")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotEqual(t, plainKey.Key, key.Key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -390,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -403,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -449,3 +464,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// revoke the key
|
||||
updateKey := key.Copy()
|
||||
updateKey.Revoked = true
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// re-activate revoked key
|
||||
updateKey.Revoked = false
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.Error(t, err, "should not allow to update revoked key")
|
||||
|
||||
}
|
||||
|
||||
@@ -33,12 +33,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
keyQueryCondition = "key = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
@@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
return status.NewGetUserFromStoreError()
|
||||
}
|
||||
user.LastLogin = lastLogin
|
||||
@@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group not found for account")
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
@@ -1079,15 +1079,51 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||
var peer *nbpeer.Peer
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// GetPeersByIDs retrieves peers by their IDs and account ID.
|
||||
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return peersMap, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
startTime := time.Now()
|
||||
tx := s.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
@@ -1098,12 +1134,21 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
|
||||
err = tx.Commit().Error
|
||||
|
||||
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
return &SqlStore{
|
||||
db: tx,
|
||||
db: tx,
|
||||
storeEngine: s.storeEngine,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1117,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
@@ -1155,12 +1201,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
|
||||
var group *nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||
@@ -1172,12 +1228,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
return nil, status.NewGroupNotFoundError(groupName)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
@@ -1185,10 +1242,10 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
@@ -1203,29 +1260,175 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete group from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
||||
var policies []*Policy
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get policies from store")
|
||||
}
|
||||
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
|
||||
var policy *Policy
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
First(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPolicyNotFoundError(policyID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get policy from store")
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to create policy in store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPolicyNotFoundError(policyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
|
||||
var postureCheck *posture.Checks
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
||||
}
|
||||
|
||||
return postureCheck, nil
|
||||
}
|
||||
|
||||
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
|
||||
}
|
||||
|
||||
postureChecksMap := make(map[string]*posture.Checks)
|
||||
for _, postureCheck := range postureChecks {
|
||||
postureChecksMap[postureCheck.ID] = postureCheck
|
||||
}
|
||||
|
||||
return postureChecksMap, nil
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture checks to the database.
|
||||
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture checks from the database.
|
||||
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPostureChecksNotFoundError(postureChecksID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
@@ -1295,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
|
||||
var nsGroups []*nbdns.NameServerGroup
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
|
||||
}
|
||||
|
||||
return nsGroups, nil
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
|
||||
}
|
||||
|
||||
return nsGroup, nil
|
||||
}
|
||||
|
||||
// SaveNameServerGroup saves a name server group to the database.
|
||||
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes a name server group from the database.
|
||||
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete name server group from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewNameServerGroupNotFoundError(nsGroupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
@@ -1335,3 +1581,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save dns settings to store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
@@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
}
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get group")
|
||||
return err
|
||||
@@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
@@ -1260,9 +1260,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "All", group.Name)
|
||||
require.True(t, group.IsGroupAll())
|
||||
}
|
||||
|
||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
@@ -1293,3 +1293,761 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing groups by existing IDs",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "non-existing group IDs",
|
||||
groupIDs: []string{"nonexistent1", "nonexistent2"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed existing and non-existing group IDs",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group := &nbgroup.Group{
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
}
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, savedGroup, group)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*nbgroup.Group{
|
||||
{
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{"peer3", "peer4"},
|
||||
},
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete existing group",
|
||||
groupID: "cfefqs706sqkneg59g4g",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existing group",
|
||||
groupID: "non-existing-group-id",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "delete with empty group ID",
|
||||
groupID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, group)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete multiple existing groups",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existing groups",
|
||||
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete with empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, groupID := range tt.groupIDs {
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, group)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing peer",
|
||||
peerID: "cfefqs706sqkneg59g4g",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing peer",
|
||||
peerID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty peer ID",
|
||||
peerID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, peer)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, peer)
|
||||
require.Equal(t, tt.peerID, peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
peerIDs []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing peers by existing IDs",
|
||||
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty peer IDs list",
|
||||
peerIDs: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "non-existing peer IDs",
|
||||
peerIDs: []string{"nonexistent1", "nonexistent2"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed existing and non-existing peer IDs",
|
||||
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
postureChecksID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing posture checks",
|
||||
postureChecksID: "csplshq7qv948l48f7t0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing posture checks",
|
||||
postureChecksID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty posture checks ID",
|
||||
postureChecksID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, postureChecks)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, postureChecks)
|
||||
require.Equal(t, tt.postureChecksID, postureChecks.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
postureCheckIDs []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing posture checks by existing IDs",
|
||||
postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty posture check IDs list",
|
||||
postureCheckIDs: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "non-existing posture check IDs",
|
||||
postureCheckIDs: []string{"nonexistent1", "nonexistent2"},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed existing and non-existing posture check IDs",
|
||||
postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"},
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePostureChecks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
postureChecks := &posture.Checks{
|
||||
ID: "posture-checks-id",
|
||||
AccountID: accountID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.31.0",
|
||||
},
|
||||
OSVersionCheck: &posture.OSVersionCheck{
|
||||
Ios: &posture.MinVersionCheck{
|
||||
MinVersion: "13.0.1",
|
||||
},
|
||||
Linux: &posture.MinKernelVersionCheck{
|
||||
MinKernelVersion: "5.3.3-dev",
|
||||
},
|
||||
},
|
||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||
Locations: []posture.Location{
|
||||
{
|
||||
CountryCode: "DE",
|
||||
CityName: "Berlin",
|
||||
},
|
||||
},
|
||||
Action: posture.CheckActionAllow,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
|
||||
require.NoError(t, err)
|
||||
|
||||
savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, savePostureChecks, postureChecks)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePostureChecks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
postureChecksID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete existing posture checks",
|
||||
postureChecksID: "csplshq7qv948l48f7t0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existing posture checks",
|
||||
postureChecksID: "non-existing-posture-checks-id",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "delete with empty posture checks ID",
|
||||
postureChecksID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, group)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPolicyByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
policyID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing policy",
|
||||
policyID: "cs1tnh0hhcjnqoiuebf0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing policy checks",
|
||||
policyID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty policy ID",
|
||||
policyID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, policy)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, policy)
|
||||
require.Equal(t, tt.policyID, policy.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_CreatePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policy := &Policy{
|
||||
ID: "policy-id",
|
||||
AccountID: accountID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, savePolicy, policy)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
policyID := "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy.Enabled = false
|
||||
policy.Description = "policy"
|
||||
policy.Rules[0].Sources = []string{"group"}
|
||||
policy.Rules[0].Ports = []string{"80", "443"}
|
||||
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, savePolicy, policy)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
policyID := "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, policy)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing account dns settings",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing account dns settings",
|
||||
accountID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve dns settings with empty account ID",
|
||||
accountID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, dnsSettings)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dnsSettings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
|
||||
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, saveDNSSettings, dnsSettings)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve name server groups by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_GetNameServerByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
nsGroupID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing nameserver group",
|
||||
nsGroupID: "csqdelq7qv97ncu7d9t0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing nameserver group",
|
||||
nsGroupID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty nameserver group ID",
|
||||
nsGroupID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, nsGroup)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, nsGroup)
|
||||
require.Equal(t, tt.nsGroupID, nsGroup.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
nsGroup := &nbdns.NameServerGroup{
|
||||
ID: "ns-group-id",
|
||||
AccountID: accountID,
|
||||
Name: "NS Group",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: 1,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Groups: []string{"groupA"},
|
||||
Primary: true,
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: false,
|
||||
}
|
||||
|
||||
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, saveNSGroup, nsGroup)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
nsGroupID := "csqdelq7qv97ncu7d9t0"
|
||||
|
||||
err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, nsGroup)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package status
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -126,11 +125,6 @@ func NewAdminPermissionError() error {
|
||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
@@ -140,3 +134,23 @@ func NewInvalidKeyIDError() error {
|
||||
func NewGetAccountError(err error) error {
|
||||
return Errorf(Internal, "error getting account: %s", err)
|
||||
}
|
||||
|
||||
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
||||
func NewGroupNotFoundError(groupID string) error {
|
||||
return Errorf(NotFound, "group: %s not found", groupID)
|
||||
}
|
||||
|
||||
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
|
||||
func NewPostureChecksNotFoundError(postureChecksID string) error {
|
||||
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
|
||||
}
|
||||
|
||||
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
|
||||
func NewPolicyNotFoundError(policyID string) error {
|
||||
return Errorf(NotFound, "policy: %s not found", policyID)
|
||||
}
|
||||
|
||||
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
|
||||
func NewNameServerGroupNotFoundError(nsGroupID string) error {
|
||||
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
|
||||
}
|
||||
|
||||
@@ -59,10 +59,11 @@ type Store interface {
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
@@ -76,13 +77,21 @@ type Store interface {
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
@@ -90,6 +99,8 @@ type Store interface {
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
@@ -106,9 +117,11 @@ type Store interface {
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
|
||||
@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
|
||||
// CountPeerMetUpdate counts the number of peer meta updates
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type StoreMetrics struct {
|
||||
globalLockAcquisitionDurationMs metric.Int64Histogram
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
transactionDurationMs metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StoreMetrics{
|
||||
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
|
||||
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
|
||||
persistenceDurationMicro: persistenceDurationMicro,
|
||||
persistenceDurationMs: persistenceDurationMs,
|
||||
transactionDurationMs: transactionDurationMs,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
|
||||
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// CountTransactionDuration counts the duration of a store persistence operation
|
||||
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
|
||||
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
@@ -34,4 +34,7 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
||||
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||
}
|
||||
|
||||
peerGroupsAdded := make(map[string][]string)
|
||||
peerGroupsRemoved := make(map[string][]string)
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
}
|
||||
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, userUpdateEvents...)
|
||||
|
||||
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, userGroupsEvents...)
|
||||
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||
if err != nil {
|
||||
@@ -835,7 +840,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
@@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
if newUser.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, removedEvents...)
|
||||
|
||||
addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
|
||||
eventsToStore = append(eventsToStore, addedEvents...)
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsAdded {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsRemoved {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
@@ -1132,7 +1183,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, account.Id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1240,7 +1291,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
|
||||
@@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
policy := &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
|
||||
@@ -140,7 +140,7 @@ type Client struct {
|
||||
instanceURL *RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func()
|
||||
onDisconnectListener func(string)
|
||||
onConnectedListener func()
|
||||
listenerMutex sync.Mutex
|
||||
}
|
||||
@@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
@@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() {
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
go c.onDisconnectListener(c.connectionURL)
|
||||
}
|
||||
|
||||
func (c *Client) notifyConnected() {
|
||||
|
||||
@@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
@@ -4,65 +4,120 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 5 * time.Second
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
ctx context.Context
|
||||
relayClient *Client
|
||||
// OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
|
||||
OnNewRelayClient chan *Client
|
||||
serverPicker *ServerPicker
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client.
|
||||
func NewGuard(context context.Context, relayClient *Client) *Guard {
|
||||
func NewGuard(sp *ServerPicker) *Guard {
|
||||
g := &Guard{
|
||||
ctx: context,
|
||||
relayClient: relayClient,
|
||||
OnNewRelayClient: make(chan *Client, 1),
|
||||
serverPicker: sp,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
|
||||
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
|
||||
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
|
||||
// to the same server that was used before, if the server URL is still valid. If the quick
|
||||
// reconnect fails, it starts a ticker to periodically attempt server picking until it
|
||||
// succeeds or the context is done.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context to control the lifecycle of the reconnection attempts.
|
||||
// - relayClient: The relay client instance that was disconnected.
|
||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||
func (g *Guard) OnDisconnected() {
|
||||
if g.quickReconnect() {
|
||||
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
||||
if relayClient == nil {
|
||||
goto RETRY
|
||||
}
|
||||
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(reconnectingTimeout)
|
||||
RETRY:
|
||||
ticker := exponentTicker(ctx)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := g.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
if err := g.retry(ctx); err != nil {
|
||||
log.Errorf("failed to pick new Relay server: %s", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
case <-g.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) quickReconnect() bool {
|
||||
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
|
||||
func (g *Guard) retry(ctx context.Context) error {
|
||||
log.Infof("try to pick up a new Relay server")
|
||||
relayClient, err := g.serverPicker.PickServer(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// prevent to work with a deprecated Relay client instance
|
||||
g.drainRelayClientChan()
|
||||
|
||||
g.OnNewRelayClient <- relayClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
|
||||
defer cancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if g.ctx.Err() != nil {
|
||||
if parentCtx.Err() != nil {
|
||||
return false
|
||||
}
|
||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||
|
||||
if err := g.relayClient.Connect(); err != nil {
|
||||
if err := rc.Connect(); err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *Guard) drainRelayClientChan() {
|
||||
select {
|
||||
case <-g.OnNewRelayClient:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) isServerURLStillValid(rc *Client) bool {
|
||||
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
|
||||
if url == rc.connectionURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func exponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 2 * time.Second,
|
||||
Multiplier: 2,
|
||||
MaxInterval: reconnectingTimeout,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
@@ -57,12 +57,15 @@ type ManagerService interface {
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
serverURLs []string
|
||||
peerID string
|
||||
tokenStore *relayAuth.TokenStore
|
||||
ctx context.Context
|
||||
peerID string
|
||||
running bool
|
||||
tokenStore *relayAuth.TokenStore
|
||||
serverPicker *ServerPicker
|
||||
|
||||
relayClient *Client
|
||||
relayClient *Client
|
||||
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
||||
relayClientMu sync.Mutex
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
@@ -76,48 +79,54 @@ type Manager struct {
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
|
||||
return &Manager{
|
||||
ctx: ctx,
|
||||
serverURLs: serverURLs,
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.TokenStore{},
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
m.reconnectGuard = NewGuard(m.serverPicker)
|
||||
return m
|
||||
}
|
||||
|
||||
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
|
||||
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
// Serve starts the manager, attempting to establish a connection with the relay server.
|
||||
// If the connection fails, it will keep trying to reconnect in the background.
|
||||
// Additionally, it starts a cleanup loop to remove unused relay connections.
|
||||
// The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
func (m *Manager) Serve() error {
|
||||
if m.relayClient != nil {
|
||||
if m.running {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
m.running = true
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: m.tokenStore,
|
||||
PeerID: m.peerID,
|
||||
}
|
||||
|
||||
client, err := sp.PickServer(m.ctx, m.serverURLs)
|
||||
client, err := m.serverPicker.PickServer(m.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
|
||||
} else {
|
||||
m.storeClient(client)
|
||||
}
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnConnectedListener(m.onServerConnected)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
go m.listenGuardEvent(m.ctx)
|
||||
go m.startCleanupLoop()
|
||||
return err
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
@@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
|
||||
// Ready returns true if the home Relay client is connected to the relay server.
|
||||
func (m *Manager) Ready() bool {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return false
|
||||
}
|
||||
@@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
}
|
||||
@@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverURLs
|
||||
return m.serverPicker.ServerURLs.Load().([]string)
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
return len(m.serverURLs) > 0
|
||||
return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateServerURLs(serverURLs []string) {
|
||||
log.Infof("update relay server URLs: %v", serverURLs)
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
}
|
||||
|
||||
// UpdateToken updates the token in the token store.
|
||||
@@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(serverAddress)
|
||||
})
|
||||
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
@@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() {
|
||||
go m.onReconnectedListenerFn()
|
||||
}
|
||||
|
||||
// onServerDisconnected start to reconnection for home server only
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
m.relayClientMu.Lock()
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go m.reconnectGuard.OnDisconnected()
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
|
||||
}
|
||||
m.relayClientMu.Unlock()
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case rc := <-m.reconnectGuard.OnNewRelayClient:
|
||||
m.storeClient(rc)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) storeClient(client *Client) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
m.relayClient = client
|
||||
m.relayClient.SetOnConnectedListener(m.onServerConnected)
|
||||
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
@@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
if m.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -12,10 +13,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
RelayClient *Client
|
||||
Url string
|
||||
@@ -24,20 +28,22 @@ type connResult struct {
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
ServerURLs atomic.Value
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(urls)
|
||||
totalServers := len(sp.ServerURLs.Load().([]string))
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range urls {
|
||||
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
|
||||
for _, url := range sp.ServerURLs.Load().([]string) {
|
||||
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||
concurrentLimiter <- struct{}{}
|
||||
go func(url string) {
|
||||
@@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
|
||||
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||
cr := <-resultChan
|
||||
if cr.Err != nil {
|
||||
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
|
||||
@@ -4,19 +4,23 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
connectionTimeout = 5 * time.Second
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
}
|
||||
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
|
||||
_, err := sp.PickServer(ctx)
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
|
||||
errCloseConn = "failed to close connection to peer: %s"
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
defer func() {
|
||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
||||
case messages.MsgTypeClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
log.Errorf(errCloseConn, err)
|
||||
}
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
68
util/file.go
68
util/file.go
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -14,8 +15,21 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare config file dir: %w", err)
|
||||
}
|
||||
|
||||
if err = EnforcePermission(file); err != nil {
|
||||
return fmt.Errorf("enforce permission: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||
// The output JSON is pretty-formatted
|
||||
func WriteJson(file string, obj interface{}) error {
|
||||
func WriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||
@@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
||||
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
|
||||
// Check context before expensive operations
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write json start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// make it pretty
|
||||
bs, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write bytes start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create temp: %w", err)
|
||||
}
|
||||
|
||||
tempFileName := tempFile.Name()
|
||||
// closing file ops as windows doesn't allow to move it
|
||||
err = tempFile.Close()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
|
||||
log.Warnf("failed to set deadline: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tempFile.Write(bs)
|
||||
if err != nil {
|
||||
return err
|
||||
_ = tempFile.Close()
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("close %s: %w", tempFileName, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
|
||||
}
|
||||
}()
|
||||
|
||||
err = os.WriteFile(tempFileName, bs, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
// Check context again
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("after temp file: %w", ctx.Err())
|
||||
}
|
||||
|
||||
err = os.Rename(tempFileName, file)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = os.Rename(tempFileName, file); err != nil {
|
||||
return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
|
||||
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
|
||||
src := tmpDir + "/copytest_src"
|
||||
dst := tmpDir + "/copytest_dst"
|
||||
|
||||
err := WriteJson(src, tt.srcContent)
|
||||
err := WriteJson(context.Background(), src, tt.srcContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = CopyFileContents(src, dst)
|
||||
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
|
||||
_ = os.Remove(cfgFile)
|
||||
}()
|
||||
|
||||
err := WriteJson(cfgFile, tt.config)
|
||||
err := WriteJson(context.Background(), cfgFile, tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(cfgFile, &TestConfig{})
|
||||
|
||||
@@ -3,6 +3,9 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get current user: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Using nbnet.NewDialer()")
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
|
||||
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
err := c.Control(func(fd uintptr) {
|
||||
androidProtectSocketLock.Lock()
|
||||
f := androidProtectSocket
|
||||
androidProtectSocketLock.Unlock()
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
ok := f(int32(fd))
|
||||
if !ok {
|
||||
log.Errorf("failed to protect socket: %d", fd)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial: %w", err)
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
@@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
@@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
5
util/net/dialer_init_android.go
Normal file
5
util/net/dialer_init_android.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = ControlProtectSocket
|
||||
}
|
||||
@@ -7,6 +7,6 @@ import "syscall"
|
||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
return SetRawSocketMark(c)
|
||||
return setRawSocketMark(c)
|
||||
}
|
||||
}
|
||||
@@ -3,4 +3,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user