Compare commits

..

1 Commits

Author SHA1 Message Date
Zoltan Papp
7ea0882975 [ios] add SSHClient gomobile binding for in-app terminal
Exposes SSHClient + SSHTerminalListener to the iOS app, mirroring the
Android binding. Connect() auto-detects the server type via banner
inspection and selects the auth path: NetBird-SSH with JWT triggers the
device-code OAuth flow via the existing URLOpener; NetBird-SSH without
JWT uses the NetBird private key; regular SSH falls back to the NetBird
key then optional password. The client dials through the running tunnel
with a plain net.Dialer and streams PTY output back to Swift via the
gomobile-bound listener for rendering in the terminal view.

Adds a config field and sshState() accessor to the iOS Client so the
SSH client can reach the active config and engine.
2026-05-27 18:29:12 +02:00
40 changed files with 1198 additions and 4770 deletions

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -900,7 +899,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -909,6 +908,26 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -28,15 +28,6 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct {
ifaceName string
spk []byte
@@ -45,7 +36,7 @@ type Manager struct {
preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler
server rpServer
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
@@ -60,22 +51,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
}
func (m *Manager) GetPubKey() []byte {
@@ -89,16 +65,6 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil {
@@ -113,16 +79,6 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
}
peerID, err := m.server.AddPeer(pcfg)
if err != nil {
@@ -226,31 +182,24 @@ func (m *Manager) Run() error {
return err
}
server, err := rp.NewUDPServer(conf)
m.server, err = rp.NewUDPServer(conf)
if err != nil {
return err
}
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port)
return server.Run()
return m.server.Run()
}
// Close closes the Rosenpass server
func (m *Manager) Close() error {
m.lock.Lock()
server := m.server
m.server = nil
m.lock.Unlock()
if server == nil {
return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
if m.server != nil {
err := m.server.Close()
if err != nil {
log.Errorf("failed closing local rosenpass server")
}
m.server = nil
}
return nil
}

View File

@@ -1,412 +1,14 @@
package rosenpass
import (
"errors"
"os"
"sync"
"testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -1,42 +0,0 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -1,43 +0,0 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -96,19 +96,17 @@ func (m *Manager) Stop(ctx context.Context) error {
}
m.mu.Lock()
cancel := m.cancel
done := m.done
m.mu.Unlock()
defer m.mu.Unlock()
if cancel == nil {
if m.cancel == nil {
return nil
}
cancel()
m.cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
case <-m.done:
}
return nil

View File

@@ -76,6 +76,9 @@ type Client struct {
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// config holds the active configuration once Run has loaded it. Consumed by
// the in-app SSH client for the NetBird SSH key and the OAuth flow.
config *profilemanager.Config
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
}
@@ -160,6 +163,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.config = cfg
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
@@ -527,6 +531,13 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// sshState returns the active config and the running connect client for the
// in-app SSH client. Both are nil until Run has loaded the config and started
// the tunnel.
func (c *Client) sshState() (*profilemanager.Config, *internal.ConnectClient) {
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -0,0 +1,431 @@
//go:build ios
package NetBirdSDK
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
gossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
sshDialTimeout = 30 * time.Second
sshDetectionTimeout = 5 * time.Second
)
// SSHTerminalListener receives SSH session events. It is implemented in Swift.
//
// All callbacks are invoked from goroutines and may run concurrently with each
// other; the implementation must be safe to call from any thread.
type SSHTerminalListener interface {
OnConnected()
OnData(data []byte)
OnClose(reason string)
OnError(message string)
}
// SSHClient is a NetBird-aware SSH client exposed to Swift via gomobile.
//
// It dials through the running NetBird tunnel and runs a standard SSH session
// on top with PTY enabled. Host-key verification uses the NetBird-provided
// peer SSH host keys, identical to the desktop client.
type SSHClient struct {
nb *Client
mu sync.Mutex
listener SSHTerminalListener
urlOpener URLOpener
sshClient *gossh.Client
session *gossh.Session
stdin io.WriteCloser
closed bool
}
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
func NewSSHClient(c *Client) *SSHClient {
return &SSHClient{nb: c}
}
// SetListener registers the Swift listener. Must be called before Connect to
// receive any events.
func (s *SSHClient) SetListener(l SSHTerminalListener) {
s.mu.Lock()
s.listener = l
s.mu.Unlock()
}
// SetURLOpener registers the Swift URL opener used to display the device-code
// authorization page in an in-app browser when the target peer requires JWT
// authentication. Must be set before Connect to be effective.
func (s *SSHClient) SetURLOpener(opener URLOpener) {
s.mu.Lock()
s.urlOpener = opener
s.mu.Unlock()
}
// Connect dials the SSH server through the NetBird tunnel and performs the
// SSH handshake. It auto-detects the server type via SSH banner inspection
// and selects the appropriate authentication path:
//
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
// flow, opens the verification URL through the registered URLOpener, and
// uses the resulting token as the SSH password. Host-key verification
// uses the NetBird peer registry.
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
// private key. Host-key verification uses the NetBird peer registry.
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
// first (so a user-installed NetBird public key works), then falls back
// to the supplied password if non-empty. Host-key verification is
// disabled (TOFU pending).
//
// The password parameter is only consulted for regular SSH servers.
func (s *SSHClient) Connect(host string, port int, user, password string) error {
cfg, cc := s.nb.sshState()
if cc == nil {
return errors.New("netbird client not running")
}
if cfg == nil {
return errors.New("netbird config not loaded")
}
engine := cc.Engine()
if engine == nil {
return errors.New("netbird engine not available")
}
serverType := detectServerType(host, port)
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
if err != nil {
return err
}
clientConfig := &gossh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
Timeout: sshDialTimeout,
}
return s.dialAndHandshake(host, port, clientConfig)
}
// StartSession requests a PTY and starts an interactive shell. Output from
// the session is forwarded to the listener via OnData.
func (s *SSHClient) StartSession(cols, rows int) error {
log.Debugf("SSH: starting session %dx%d", cols, rows)
s.mu.Lock()
sshClient := s.sshClient
s.mu.Unlock()
if sshClient == nil {
return errors.New("ssh client not connected")
}
session, err := sshClient.NewSession()
if err != nil {
return fmt.Errorf("new session: %w", err)
}
modes := gossh.TerminalModes{
gossh.ECHO: 1,
gossh.TTY_OP_ISPEED: 14400,
gossh.TTY_OP_OSPEED: 14400,
gossh.VINTR: 3,
gossh.VQUIT: 28,
gossh.VERASE: 127,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
closeQuiet(session, "session after pty error")
return fmt.Errorf("request pty: %w", err)
}
stdin, err := session.StdinPipe()
if err != nil {
closeQuiet(session, "session after stdin error")
return fmt.Errorf("stdin pipe: %w", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
closeQuiet(session, "session after stdout error")
return fmt.Errorf("stdout pipe: %w", err)
}
stderr, err := session.StderrPipe()
if err != nil {
closeQuiet(session, "session after stderr error")
return fmt.Errorf("stderr pipe: %w", err)
}
if err := session.Shell(); err != nil {
closeQuiet(session, "session after shell error")
return fmt.Errorf("start shell: %w", err)
}
s.mu.Lock()
s.session = session
s.stdin = stdin
s.mu.Unlock()
go s.readLoop(stdout, "stdout")
go s.readLoop(stderr, "stderr")
log.Debug("SSH: session started, shell running")
return nil
}
// Write sends data to the SSH session stdin.
func (s *SSHClient) Write(data []byte) error {
s.mu.Lock()
stdin := s.stdin
s.mu.Unlock()
if stdin == nil {
return errors.New("ssh session not started")
}
if _, err := stdin.Write(data); err != nil {
return fmt.Errorf("write stdin: %w", err)
}
return nil
}
// Resize updates the PTY window size.
func (s *SSHClient) Resize(cols, rows int) error {
s.mu.Lock()
session := s.session
s.mu.Unlock()
if session == nil {
return errors.New("ssh session not started")
}
return session.WindowChange(rows, cols)
}
// Close terminates the SSH session and underlying connection. Safe to call
// multiple times.
func (s *SSHClient) Close() error {
s.mu.Lock()
sshClient := s.sshClient
session := s.session
stdin := s.stdin
s.sshClient = nil
s.session = nil
s.stdin = nil
s.mu.Unlock()
if stdin != nil {
if err := stdin.Close(); err != nil {
log.Debugf("ssh: stdin close: %v", err)
}
}
if session != nil {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: session close: %v", err)
}
}
var firstErr error
if sshClient != nil {
if err := sshClient.Close(); err != nil {
firstErr = err
}
}
s.notifyClose("closed by client")
return firstErr
}
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
switch serverType {
case detection.ServerTypeNetBirdJWT:
token, err := s.requestJWTToken(cfg)
if err != nil {
return nil, nil, fmt.Errorf("jwt: %w", err)
}
auths := []gossh.AuthMethod{gossh.Password(token)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
case detection.ServerTypeNetBirdNoJWT:
if cfg.SSHKey == "" {
return nil, nil, errors.New("no NetBird SSH key available")
}
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
if err != nil {
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
}
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
default: // regular SSH
var auths []gossh.AuthMethod
if cfg.SSHKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
auths = append(auths, gossh.PublicKeys(signer))
} else {
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
}
}
if password != "" {
pw := password
auths = append(auths, gossh.Password(pw))
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
answers := make([]string, len(questions))
for i := range questions {
answers[i] = pw
}
return answers, nil
}))
}
if len(auths) == 0 {
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
}
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
}
}
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
s.mu.Lock()
urlOpener := s.urlOpener
s.mu.Unlock()
if urlOpener == nil {
return "", errors.New("URL opener not configured for JWT auth")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
if err != nil {
return "", fmt.Errorf("create oauth flow: %w", err)
}
flowInfo, err := flow.RequestAuthInfo(ctx)
if err != nil {
return "", fmt.Errorf("request auth info: %w", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
if err != nil {
return "", fmt.Errorf("wait for token: %w", err)
}
token := tokenInfo.GetTokenToUse()
if token == "" {
return "", errors.New("empty token returned by IdP")
}
return token, nil
}
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel()
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("dial %s: %w", addr, err)
}
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
if err != nil {
if cerr := conn.Close(); cerr != nil {
log.Debugf("ssh: close after handshake error: %v", cerr)
}
return fmt.Errorf("ssh handshake: %w", err)
}
s.mu.Lock()
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
listener := s.listener
s.mu.Unlock()
log.Infof("SSH: connected to %s", addr)
if listener != nil {
listener.OnConnected()
}
return nil
}
func (s *SSHClient) readLoop(r io.Reader, name string) {
buf := make([]byte, 4096)
for {
n, err := r.Read(buf)
if n > 0 {
s.mu.Lock()
listener := s.listener
s.mu.Unlock()
if listener != nil {
chunk := make([]byte, n)
copy(chunk, buf[:n])
listener.OnData(chunk)
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
log.Debugf("ssh %s read: %v", name, err)
}
s.notifyClose(err.Error())
return
}
}
}
func (s *SSHClient) notifyClose(reason string) {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return
}
s.closed = true
listener := s.listener
s.mu.Unlock()
if listener != nil {
listener.OnClose(reason)
}
}
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
type engineHostKeyVerifier struct {
engine *internal.Engine
}
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
if !found {
return nbssh.ErrPeerNotFound
}
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
}
func detectServerType(host string, port int) detection.ServerType {
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
defer cancel()
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
return detection.ServerTypeRegular
}
return serverType
}
func closeQuiet(c io.Closer, label string) {
if c == nil {
return
}
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: close %s: %v", label, err)
}
}

10
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5
require (
cunicu.li/go-rosenpass v0.5.42
cunicu.li/go-rosenpass v0.4.0
github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.19.0
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.4.0
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2

22
go.sum
View File

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,8 +390,6 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -902,8 +900,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=

View File

@@ -44,7 +44,7 @@ type Controller struct {
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -63,13 +63,6 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -207,7 +200,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
@@ -232,6 +225,44 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
@@ -241,143 +272,6 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
}
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -486,100 +380,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {
return nil
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(ctx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(ctx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -777,24 +577,21 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
return nil
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -827,11 +624,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping network map update", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,8 +19,6 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
@@ -29,9 +27,9 @@ type Controller interface {
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,20 +57,6 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -172,45 +158,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
}
// StartWarmup mocks base method.
@@ -264,17 +250,3 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -2562,9 +2562,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -2655,9 +2653,7 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
}
if updateNetworkMap {
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
return fmt.Errorf("notify network map controller: %w", err)
}
}

View File

@@ -127,8 +127,6 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -122,18 +122,6 @@ func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reas
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockManagerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// BuildUserInfosForAccount mocks base method.
func (m *MockManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
m.ctrl.T.Helper()
@@ -1634,18 +1622,6 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -3282,19 +3282,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
for {
select {
case _, ok := <-ch:
if !ok {
return
}
case <-time.After(200 * time.Millisecond):
return
}
}
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {

View File

@@ -1,333 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// collectPeerChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the
// changed groups and/or changed peers. Each collection is fetched from the store exactly once.
func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 && len(changedPeerIDs) == 0 {
return nil, nil
}
changedGroupSet := toSet(changedGroupIDs)
changedPeerSet := toSet(changedPeerIDs)
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
collectAffectedFromPolicies(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromRoutes(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromNameServers(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromProxyServices(ctx, transaction, accountID, changedGroupSet, changedPeerSet, peerSet)
allGroupIDs = setToSlice(groupSet)
directPeerIDs = setToSlice(peerSet)
log.WithContext(ctx).Tracef("affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v",
changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}
// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) {
return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil)
}
func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err)
return
}
for _, policy := range policies {
matchedByGroup := policyReferencesGroups(policy, changedGroupSet)
matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, peerSet)
}
}
func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err)
return
}
for _, r := range routes {
matchedByGroup := routeReferencesGroups(r, changedGroupSet)
matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err)
return
}
for _, ns := range nsGroups {
if anyInSet(ns.Groups, changedGroupSet) {
addAllToSet(groupSet, ns.Groups)
}
}
}
func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err)
return
}
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedGroupSet[gID]; ok {
groupSet[gID] = struct{}{}
}
}
}
func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err)
return
}
for _, router := range routers {
matchedByGroup := routerReferencesGroups(router, changedGroupSet)
matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, router.PeerGroups)
if router.Peer != "" {
peerSet[router.Peer] = struct{}{}
}
}
}
// collectAffectedFromProxyServices handles policies that are synthesized at
// network-map computation time by Account.InjectProxyPolicies. Those policies
// connect proxy peers (peer.ProxyMeta.Embedded == true) to service targets and
// never reach the database, so the other collectors cannot see them.
func collectAffectedFromProxyServices(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, peerSet map[string]struct{}) {
if len(changedGroupSet) == 0 && len(changedPeerSet) == 0 {
return
}
services, proxyByCluster, ok := loadProxyServiceContext(ctx, transaction, accountID)
if !ok {
return
}
expandedPeerSet := expandChangedPeersWithGroups(ctx, transaction, accountID, changedGroupSet, changedPeerSet)
for _, svc := range services {
if svc == nil {
continue
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
if !serviceMatchesChangedPeers(svc, proxyPeers, expandedPeerSet) {
continue
}
log.WithContext(ctx).Tracef("collectAffectedFromProxyServices: service %s (cluster=%s) matched; folding %d proxy peers and target peers",
svc.ID, svc.ProxyCluster, len(proxyPeers))
addServicePeersToSet(svc, proxyPeers, peerSet)
}
}
func loadProxyServiceContext(ctx context.Context, transaction store.Store, accountID string) ([]*rpservice.Service, map[string][]string, bool) {
services, err := transaction.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get services for affected group resolution: %v", err)
return nil, nil, false
}
if len(services) == 0 {
return nil, nil, false
}
proxyByCluster, err := transaction.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get embedded proxy peers for affected group resolution: %v", err)
return nil, nil, false
}
if len(proxyByCluster) == 0 {
return nil, nil, false
}
return services, proxyByCluster, true
}
func expandChangedPeersWithGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}) map[string]struct{} {
if len(changedGroupSet) == 0 {
return changedPeerSet
}
ids, err := transaction.GetPeerIDsByGroups(ctx, accountID, setToSlice(changedGroupSet))
if err != nil {
log.WithContext(ctx).Errorf("failed to expand changed groups to peers for service resolution: %v", err)
return changedPeerSet
}
if len(ids) == 0 {
return changedPeerSet
}
merged := make(map[string]struct{}, len(changedPeerSet)+len(ids))
for id := range changedPeerSet {
merged[id] = struct{}{}
}
for _, id := range ids {
merged[id] = struct{}{}
}
return merged
}
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
for _, pid := range proxyPeers {
if _, ok := changedPeers[pid]; ok {
return true
}
}
for _, target := range svc.Targets {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {
return true
}
}
return false
}
func addServicePeersToSet(svc *rpservice.Service, proxyPeers []string, peerSet map[string]struct{}) {
for _, pid := range proxyPeers {
peerSet[pid] = struct{}{}
}
for _, target := range svc.Targets {
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
peerSet[target.TargetId] = struct{}{}
}
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
if res.Type != types.ResourceTypePeer || res.ID == "" {
return false
}
_, ok := set[res.ID]
return ok
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet)
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
return anyInSet(router.PeerGroups, groupSet)
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAllToSet(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
set[id] = struct{}{}
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}

File diff suppressed because it is too large Load Diff

View File

@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,6 +63,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -70,9 +75,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
allGroups := slices.Concat(addedGroups, removedGroups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -83,11 +85,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveDNSSettings: updating %d affected peers: %v", len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveDNSSettings: no affected peers")
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -134,6 +133,20 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -9,12 +9,15 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -76,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -88,6 +91,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -98,9 +106,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -111,11 +116,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateGroup %s: no affected peers", newGroup.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
return nil
@@ -132,7 +134,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -163,6 +165,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
@@ -171,9 +178,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, append(directPeerIDs, peersToRemove...))
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -184,11 +188,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateGroup %s: no affected peers", newGroup.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -208,6 +209,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -245,17 +247,17 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateGroups %v: no affected peers", groupIDs)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
return globalErr
@@ -275,6 +277,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -292,17 +295,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
groupIDs = append(groupIDs, newGroup.ID)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateGroups %v: no affected peers", groupIDs)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return globalErr
@@ -485,10 +488,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var affectedPeerIDs []string
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
@@ -497,20 +505,14 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupAddPeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupAddPeer group=%s peer=%s: no affected peers", groupID, peerID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -519,7 +521,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var affectedPeerIDs []string
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -532,12 +534,14 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
@@ -545,11 +549,8 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupAddResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupAddResource group=%s resource=%s: no affected peers", groupID, resource.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -557,13 +558,14 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var affectedPeerIDs []string
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Resolve before removing, so the peer being removed is still included
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
@@ -579,11 +581,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupDeletePeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupDeletePeer group=%s peer=%s: no affected peers", groupID, peerID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -592,7 +591,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var affectedPeerIDs []string
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -605,12 +604,14 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
@@ -618,11 +619,8 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupDeleteResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupDeleteResource group=%s resource=%s: no affected peers", groupID, resource.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -653,3 +651,230 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil
}
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to get user")
}
if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"network router", linkedRouter.ID}
}
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings")
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get account settings")
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
}
for _, router := range routers {
if slices.Contains(router.PeerGroups, groupID) {
return true, router
}
}
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
return true, nil
}
}
return false, nil
}

View File

@@ -1,287 +0,0 @@
package server
import (
"context"
"slices"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
)
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to get user")
}
if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"network router", linkedRouter.ID}
}
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings")
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get account settings")
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
}
for _, router := range routers {
if slices.Contains(router.PeerGroups, groupID) {
return true, router
}
}
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
// It fetches each collection once and checks all groupIDs against them in memory.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
groupSet := make(map[string]struct{}, len(groupIDs))
for _, id := range groupIDs {
groupSet[id] = struct{}{}
}
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
return false, nil
}
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
}
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, ns := range nameServerGroups {
if anyInSet(ns.Groups, groupSet) {
return true, nil
}
}
return false, nil
}
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true, nil
}
}
}
return false, nil
}
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, r := range routes {
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
return true, nil
}
}
return false, nil
}
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, router := range routers {
if anyInSet(router.PeerGroups, groupSet) {
return true, nil
}
}
return false, nil
}

View File

@@ -131,8 +131,6 @@ type MockAccountManager struct {
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
@@ -210,18 +208,6 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
}
}
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
if am.UpdateAffectedPeersFunc != nil {
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
}
}
func (am *MockAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
if am.BufferUpdateAffectedPeersFunc != nil {
am.BufferUpdateAffectedPeersFunc(ctx, accountID, peerIDs, reason)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)

View File

@@ -4,12 +4,10 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"unicode/utf8"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
@@ -59,18 +57,21 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
@@ -80,11 +81,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateNameServerGroup %s: updating %d affected peers: %v", newNSGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateNameServerGroup %s: no affected peers", newNSGroup.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
}
return newNSGroup.Copy(), nil
@@ -104,7 +102,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.NewPermissionDeniedError()
}
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
@@ -117,12 +115,14 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
@@ -132,11 +132,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveNameServerGroup %s: updating %d affected peers: %v", nsGroupToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveNameServerGroup %s: no affected peers", nsGroupToSave.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -153,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
var nsGroup *nbdns.NameServerGroup
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
@@ -161,7 +158,10 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
return err
@@ -175,11 +175,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteNameServerGroup %s: updating %d affected peers: %v", nsGroupID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteNameServerGroup %s: no affected peers", nsGroupID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
}
return nil
@@ -227,6 +224,24 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -16,7 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
nbTypes "github.com/netbirdio/netbird/management/server/types"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -113,14 +112,6 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return network, m.store.SaveNetwork(ctx, network)
}
// networkAffectedPeersData holds data loaded inside the transaction for affected peer resolution.
type networkAffectedPeersData struct {
resourceGroupIDs []string
routerPeerGroups []string
directPeerIDs []string
policies []*nbTypes.Policy
}
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -136,22 +127,13 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
}
var eventsToStore []func()
var affectedData *networkAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get resources in network: %w", err)
}
var resourceGroupIDs []string
for _, resource := range resources {
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err == nil {
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
}
event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
@@ -159,19 +141,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event...)
}
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get routers in network: %w", err)
}
var routerPeerGroups []string
var directPeerIDs []string
for _, router := range netRouters {
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
if router.Peer != "" {
directPeerIDs = append(directPeerIDs, router.Peer)
}
for _, router := range routers {
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
if err != nil {
return fmt.Errorf("failed to delete router: %w", err)
@@ -179,24 +154,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event)
}
// load policies before deleting so group memberships are still present
var policies []*nbTypes.Policy
if len(resourceGroupIDs) > 0 {
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for affected peers: %v", err)
}
}
if len(resourceGroupIDs) > 0 || len(routerPeerGroups) > 0 || len(directPeerIDs) > 0 {
affectedData = &networkAffectedPeersData{
resourceGroupIDs: resourceGroupIDs,
routerPeerGroups: routerPeerGroups,
directPeerIDs: directPeerIDs,
policies: policies,
}
}
err = transaction.DeleteNetwork(ctx, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to delete network: %w", err)
@@ -221,111 +178,11 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
if affectedData != nil {
affectedPeerIDs := resolveNetworkAffectedPeers(ctx, m.store, accountID, affectedData)
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteNetwork %s: updating %d affected peers: %v", networkID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteNetwork %s: no affected peers", networkID)
}
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
return nil
}
// resolveNetworkAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func resolveNetworkAffectedPeers(ctx context.Context, s store.Store, accountID string, data *networkAffectedPeersData) []string {
log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: routerPeerGroups=%v, resourceGroupIDs=%v, directPeerIDs=%v, policies=%d",
data.routerPeerGroups, data.resourceGroupIDs, data.directPeerIDs, len(data.policies))
groupSet := make(map[string]struct{})
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
if len(data.resourceGroupIDs) > 0 {
for _, gID := range data.resourceGroupIDs {
groupSet[gID] = struct{}{}
}
collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, s, accountID, groupSet, data.directPeerIDs)
log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs
// and adds their source groups to the groupSet.
func collectPolicySourceGroups(policies []*nbTypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
if ruleMatchesDestinations(rule, destSet) {
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
}
}
// ruleMatchesDestinations checks if a policy rule references any of the destination groups.
func ruleMatchesDestinations(rule *nbTypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
// resolveGroupsAndDirectPeers resolves group IDs and direct peer IDs into a deduplicated peer ID list.
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -114,11 +114,45 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
}
var eventsToStore []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var txErr error
eventsToStore, affectedData, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource)
return txErr
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
}
eventsToStore = append(eventsToStore, event)
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to create network resource: %w", err)
@@ -128,60 +162,11 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateResource %s: no affected peers", resource.ID)
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
return resource, nil
}
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource) ([]func(), *resourceAffectedPeersData, error) {
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
}
var eventsToStore []func()
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
})
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err := loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return eventsToStore, affectedData, nil
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
@@ -222,7 +207,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Prefix = prefix
var eventsToStore []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
@@ -248,15 +232,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network resource: %w", err)
}
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, oldResource.AccountID, oldResource.ID)
if err != nil {
return fmt.Errorf("failed to get old resource groups: %w", err)
}
var oldGroupIDs []string
for _, g := range oldGroups {
oldGroupIDs = append(oldGroupIDs, g.ID)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
@@ -272,11 +247,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
})
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, append(resource.GroupIDs, oldGroupIDs...))
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
@@ -300,12 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
}
}()
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateResource %s: no affected peers", resource.ID)
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
return resource, nil
}
@@ -366,22 +331,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
}
var events []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to get resource groups: %w", err)
}
var resourceGroupIDs []string
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, accountID, networkID, resourceGroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
@@ -402,12 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteResource %s: updating %d affected peers: %v", resourceID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteResource %s: no affected peers", resourceID)
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
return nil
}
@@ -454,151 +399,6 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
return eventsToStore, nil
}
// resourceAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
type resourceAffectedPeersData struct {
resourceGroupIDs []string
policies []*nbtypes.Policy
routerPeerGroups []string
routerDirectPeers []string
}
// loadResourceAffectedPeersData loads the data needed to determine affected peers within a transaction.
func loadResourceAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, resourceGroupIDs []string) (*resourceAffectedPeersData, error) {
if len(resourceGroupIDs) == 0 {
return &resourceAffectedPeersData{}, nil
}
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get policies: %w", err)
}
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get routers: %w", err)
}
var routerPeerGroups []string
var routerDirectPeers []string
for _, router := range routers {
if !router.Enabled {
continue
}
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
if router.Peer != "" {
routerDirectPeers = append(routerDirectPeers, router.Peer)
}
}
return &resourceAffectedPeersData{
resourceGroupIDs: resourceGroupIDs,
policies: policies,
routerPeerGroups: routerPeerGroups,
routerDirectPeers: routerDirectPeers,
}, nil
}
// resolveResourceAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountID string, data *resourceAffectedPeersData) []string {
if data == nil {
return nil
}
log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: resourceGroupIDs=%v, routerPeerGroups=%v, routerDirectPeers=%v, policies=%d",
data.resourceGroupIDs, data.routerPeerGroups, data.routerDirectPeers, len(data.policies))
groupSet := make(map[string]struct{})
directPeerIDs := collectResourcePolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
directPeerIDs = append(directPeerIDs, data.routerDirectPeers...)
if len(groupSet) == 0 && len(directPeerIDs) == 0 {
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, directPeerIDs)
log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectResourcePolicySourceGroups finds policies whose rules reference the resource destination groups,
// adds their source groups to groupSet, and returns any direct peer IDs from source resources.
func collectResourcePolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) []string {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
var directPeerIDs []string
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
directPeerIDs = collectSourcesFromPolicyRules(policy.Rules, destSet, groupSet, directPeerIDs)
}
return directPeerIDs
}
func collectSourcesFromPolicyRules(rules []*nbtypes.PolicyRule, destSet map[string]struct{}, groupSet map[string]struct{}, directPeerIDs []string) []string {
for _, rule := range rules {
if rule == nil || !rule.Enabled {
continue
}
if !ruleMatchesDestinations(rule, destSet) {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
}
return directPeerIDs
}
func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -16,7 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -91,7 +90,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
@@ -114,11 +112,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, router.PeerGroups, router.Peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
})
if err != nil {
@@ -127,12 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateRouter %s: no affected peers", router.ID)
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
return router, nil
}
@@ -168,11 +156,36 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var txErr error
network, affectedData, txErr = m.updateRouterInTransaction(ctx, transaction, router)
return txErr
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
err = transaction.UpdateNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
@@ -180,61 +193,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateRouter %s: no affected peers", router.ID)
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
return router, nil
}
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, *routerAffectedPeersData, error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return nil, nil, status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return nil, nil, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
allPeerGroups := append([]string{}, router.PeerGroups...)
allPeerGroups = append(allPeerGroups, existing.PeerGroups...)
var directPeers []string
if router.Peer != "" {
directPeers = append(directPeers, router.Peer)
}
if existing.Peer != "" {
directPeers = append(directPeers, existing.Peer)
}
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
return nil, nil, fmt.Errorf("failed to update network router: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err := loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return network, affectedData, nil
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -245,19 +208,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
var event func()
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil {
return fmt.Errorf("failed to get router: %w", err)
}
// load before delete so group memberships are still present
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, accountID, networkID, router.PeerGroups, router.Peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
if err != nil {
return fmt.Errorf("failed to delete network router: %w", err)
@@ -276,12 +227,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteRouter %s: updating %d affected peers: %v", routerID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteRouter %s: no affected peers", routerID)
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
return nil
}
@@ -313,153 +259,6 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
return event, nil
}
// routerAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
type routerAffectedPeersData struct {
routerPeerGroups []string
directPeerIDs []string
resourceGroupIDs []string
policies []*nbtypes.Policy
}
// loadRouterAffectedPeersData loads the data needed to determine affected peers within a transaction.
func loadRouterAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, routerPeerGroups []string, directPeers ...string) (*routerAffectedPeersData, error) {
var directPeerIDs []string
for _, p := range directPeers {
if p != "" {
directPeerIDs = append(directPeerIDs, p)
}
}
if len(routerPeerGroups) == 0 && len(directPeerIDs) == 0 {
return &routerAffectedPeersData{}, nil
}
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err)
}
var resourceGroupIDs []string
for _, resource := range resources {
if !resource.Enabled {
continue
}
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil {
return nil, fmt.Errorf("failed to get groups for resource %s: %w", resource.ID, err)
}
for _, g := range groups {
resourceGroupIDs = append(resourceGroupIDs, g.ID)
}
}
var policies []*nbtypes.Policy
if len(resourceGroupIDs) > 0 {
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get policies: %w", err)
}
}
return &routerAffectedPeersData{
routerPeerGroups: routerPeerGroups,
directPeerIDs: directPeerIDs,
resourceGroupIDs: resourceGroupIDs,
policies: policies,
}, nil
}
// resolveRouterAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
func (m *managerImpl) resolveRouterAffectedPeers(ctx context.Context, accountID string, data *routerAffectedPeersData) []string {
if data == nil {
return nil
}
log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: routerPeerGroups=%v, directPeerIDs=%v, resourceGroupIDs=%v, policies=%d",
data.routerPeerGroups, data.directPeerIDs, data.resourceGroupIDs, len(data.policies))
groupSet := make(map[string]struct{})
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
}
if len(data.resourceGroupIDs) > 0 {
collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, data.directPeerIDs)
log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs
// and adds their source groups to the groupSet.
func collectPolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
if ruleMatchesDestinations(rule, destSet) {
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
}
}
func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -120,23 +120,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
// An embedded proxy peer flipping to connected is the trigger for
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
// tunnel IP. Fan the change out to every peer that has a synthesized
// policy or DNS-record edge from this proxy peer; otherwise their
// synth state stays stale until some other change pokes the controller.
// tunnel IP. Without an account-wide netmap recompute, user peers keep
// the stale synth (or no synth at all on first connect) until some
// other change pokes the controller. Fire OnPeersUpdated so the
// buffered recompute fans the new state out to every peer.
if peer.ProxyMeta.Embedded {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
}
}
@@ -178,12 +174,11 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
// offline, refresh the peers that had synthesized records pointing at
// it so they pull the stale entries instead of waiting out TTL.
// offline, drive an account-wide netmap recompute so the synthesized
// DNS records that pointed at it are pulled. Without this the records
// linger client-side at TTL until something else triggers a refresh.
if peer.ProxyMeta.Embedded {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
}
}
@@ -350,10 +345,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
affectedPeerIDs = append(affectedPeerIDs, peer.ID)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -511,7 +503,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var peer *nbpeer.Peer
var settings *types.Settings
var eventsToStore []func()
var affectedPeerIDs []string
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
@@ -536,8 +527,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID})
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -561,7 +550,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
@@ -934,9 +923,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
@@ -1024,9 +1011,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1156,9 +1141,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1350,64 +1333,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
ctx = context.WithoutCancel(ctx)
log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID)
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolvePeerIDs resolves group IDs and direct peer IDs into a deduplicated peer ID list.
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers: %v", groupIDs, len(peerIDs), peerIDs)
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers: %v", groupIDs, directPeerIDs, len(peerIDs), peerIDs)
return peerIDs
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason)
}
// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of affected peer IDs.
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err)
return nil
}
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs)
// Single pass: find entities referencing the changed groups OR the changed peers directly
allGroupIDs, directPeerIDs := collectPeerChangeAffectedGroups(ctx, s, accountID, groupIDs, changedPeerIDs)
result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs)
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))
return result
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}

View File

@@ -1855,7 +1855,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1880,7 +1880,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2018,10 +2018,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
})
// drain any buffered updates from previous subtests
drainPeerUpdates(updMsg)
// Adding peer to group linked with route should update peers in that group, not unrelated peers
// Adding peer to group linked with route should update account peers and send peer update
t.Run("adding peer to group linked with route", func(t *testing.T) {
route := nbroute.Route{
ID: "testingRoute1",
@@ -2045,7 +2042,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2062,16 +2059,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting peer with linked group to route should update peers in that group, not unrelated peers
// Deleting peer with linked group to route should update account peers and send peer update
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2080,12 +2077,12 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Adding peer to group linked with name server group should update peers in that group, not unrelated peers
// Adding peer to group linked with name server group should update account peers and send peer update
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
@@ -2100,7 +2097,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2117,16 +2114,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Deleting peer with linked group to name server group should update peers in that group, not unrelated peers
// Deleting peer with linked group to name server group should update account peers and send peer update
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2135,8 +2132,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
}

View File

@@ -5,7 +5,7 @@ import (
_ "embed"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -45,38 +45,44 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
var isUpdate = policy.ID != ""
var existingPolicy *types.Policy
var updateAccountPeers bool
var action = activity.PolicyAdded
var unchanged bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
if err != nil {
return err
}
if isUpdate {
if policy.Equal(existingPolicy) {
log.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
unchanged = true
return nil
}
action = activity.PolicyUpdated
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(ctx, policy, existingPolicy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -89,11 +95,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Tracef("SavePolicy %s: updating %d affected peers: %v", policy.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SavePolicy %s: no affected peers", policy.ID)
if updateAccountPeers {
policyOp := types.UpdateOperationCreate
if isUpdate {
policyOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
}
return policy, nil
@@ -110,7 +117,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
}
var policy *types.Policy
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
@@ -118,8 +125,10 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(ctx, policy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
return err
@@ -133,11 +142,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeletePolicy %s: updating %d affected peers: %v", policyID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeletePolicy %s: no affected peers", policyID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
}
return nil
@@ -156,28 +162,44 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// collectPolicyAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given policies.
func collectPolicyAffectedGroupsAndPeers(ctx context.Context, policies ...*types.Policy) (groupIDs []string, directPeerIDs []string) {
for _, policy := range policies {
if policy == nil {
continue
}
ruleGroups := policy.RuleGroups()
log.WithContext(ctx).Tracef("collectPolicyAffectedGroupsAndPeers: policy %s (%s) ruleGroups=%v", policy.ID, policy.Name, ruleGroups)
groupIDs = append(groupIDs, ruleGroups...)
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
log.WithContext(ctx).Tracef("collectPolicyAffectedGroupsAndPeers: policy %s rule %s direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID)
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
log.WithContext(ctx).Tracef("collectPolicyAffectedGroupsAndPeers: policy %s rule %s direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID)
directPeerIDs = append(directPeerIDs, rule.DestinationResource.ID)
}
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
log.WithContext(ctx).Tracef("collectPolicyAffectedGroupsAndPeers: result groupIDs=%v, directPeerIDs=%v", groupIDs, directPeerIDs)
return
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
// validatePolicy validates the policy and its rules. For updates it returns

View File

@@ -1319,14 +1319,12 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
}
})
// Updating disabled policy with destination and source groups containing peers should still update account's peers
// because affected peer resolution does not filter by policy enabled state
// Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
drainPeerUpdates(updMsg)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1337,8 +1335,8 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})

View File

@@ -5,13 +5,13 @@ import (
"slices"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -41,9 +41,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
@@ -51,10 +51,12 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if isUpdate {
action = activity.PostureCheckUpdated
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
@@ -74,11 +76,12 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SavePostureChecks %s: updating %d affected peers: %v", postureChecks.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SavePostureChecks %s: no affected peers", postureChecks.ID)
if updateAccountPeers {
postureOp := types.UpdateOperationCreate
if isUpdate {
postureOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
}
return postureChecks, nil
@@ -134,25 +137,27 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// collectPostureCheckAffectedGroupsAndPeers returns group IDs and peer IDs from policies referencing the posture check.
func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) {
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for posture check affected peers resolution: %v", err)
return nil, nil
return false, err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
log.WithContext(ctx).Tracef("collectPostureCheckAffectedGroupsAndPeers: posture check %s referenced by policy %s (%s)", postureCheckID, policy.ID, policy.Name)
gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(ctx, policy)
groupIDs = append(groupIDs, gIDs...)
directPeerIDs = append(directPeerIDs, pIDs...)
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
log.WithContext(ctx).Tracef("collectPostureCheckAffectedGroupsAndPeers: postureCheck=%s -> groupIDs=%v, directPeerIDs=%v", postureCheckID, groupIDs, directPeerIDs)
return groupIDs, directPeerIDs
return false, nil
}
// validatePostureChecks validates the posture checks.

View File

@@ -503,20 +503,21 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, "unknown")
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
@@ -525,8 +526,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
@@ -535,8 +537,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
@@ -544,9 +547,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to save groups")
// The collector returns groups even if they have no peers — the groups are still referenced
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
@@ -555,10 +558,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
// Non-existent groups are filtered out during SavePolicy validation,
// so the saved policy has empty Sources/Destinations
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
}

View File

@@ -8,7 +8,6 @@ import (
"unicode/utf8"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -148,7 +147,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
var newRoute *route.Route
var affectedPeerIDs []string
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
newRoute = &route.Route{
@@ -174,12 +173,14 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(ctx, newRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
@@ -189,11 +190,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateRoute %s: updating %d affected peers: %v", newRoute.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateRoute %s: no affected peers", newRoute.ID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
}
return newRoute, nil
@@ -210,7 +208,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
var oldRoute *route.Route
var affectedPeerIDs []string
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
@@ -222,15 +221,21 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(ctx, routeToSave, oldRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -239,11 +244,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveRoute %s: updating %d affected peers: %v", routeToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveRoute %s: no affected peers", routeToSave.ID)
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
}
return nil
@@ -259,17 +261,19 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return status.NewPermissionDeniedError()
}
var rt *route.Route
var affectedPeerIDs []string
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(ctx, rt)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
return err
@@ -281,13 +285,10 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteRoute %s: updating %d affected peers: %v", routeID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteRoute %s: no affected peers", routeID)
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
}
return nil
@@ -376,23 +377,23 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
// collectRouteAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given routes.
func collectRouteAffectedGroupsAndPeers(ctx context.Context, routes ...*route.Route) (groupIDs []string, directPeerIDs []string) {
for _, r := range routes {
if r == nil {
continue
}
log.WithContext(ctx).Tracef("collectRouteAffectedGroupsAndPeers: route %s groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
r.ID, r.Groups, r.PeerGroups, r.AccessControlGroups, r.Peer)
groupIDs = append(groupIDs, r.Groups...)
groupIDs = append(groupIDs, r.PeerGroups...)
groupIDs = append(groupIDs, r.AccessControlGroups...)
if r.Peer != "" {
directPeerIDs = append(directPeerIDs, r.Peer)
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
}
log.WithContext(ctx).Tracef("collectRouteAffectedGroupsAndPeers: result groupIDs=%v, directPeerIDs=%v", groupIDs, directPeerIDs)
return
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix

View File

@@ -1962,10 +1962,8 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
})
// Creating a route with no routing peer and having peers in groups that don't include peer1 should not send peer1 an update
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
drainPeerUpdates(updMsg)
route := route.Route{
ID: "testingRoute2",
Network: netip.MustParsePrefix("192.0.2.0/32"),
@@ -1981,7 +1979,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1994,8 +1992,8 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -265,8 +265,7 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
// Deprecated: Full
// account operations are no longer supported
// Deprecated: Full account operations are no longer supported
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
defer func() {
@@ -4895,64 +4894,6 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
return peers, nil
}
func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
if len(groupIDs) == 0 {
return nil, nil
}
var peerIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs).
Pluck("peer_id", &peerIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error)
}
return peerIDs, nil
}
func (s *SqlStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
if len(peerIDs) == 0 {
return nil, nil
}
var groupIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT group_id").
Where("account_id = ? AND peer_id IN ?", accountID, peerIDs).
Pluck("group_id", &groupIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get group IDs by peers: %s", result.Error)
}
return groupIDs, nil
}
// GetEmbeddedProxyPeerIDsByCluster returns peer IDs of all embedded proxy peers
// in the account, grouped by their ProxyCluster. The map is nil when no embedded
// proxy peers exist.
func (s *SqlStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
type row struct {
ID string
Cluster string
}
var rows []row
result := s.db.Model(&nbpeer.Peer{}).
Select("id, proxy_meta_cluster AS cluster").
Where("account_id = ? AND proxy_meta_embedded = ?", accountID, true).
Scan(&rows)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get embedded proxy peers: %s", result.Error)
}
out := make(map[string][]string, len(rows))
for _, r := range rows {
out[r.Cluster] = append(out[r.Cluster], r.ID)
}
return out, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -162,9 +162,6 @@ type Store interface {
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error)
GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)

View File

@@ -1925,51 +1925,6 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1165,8 +1165,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
}
}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, peerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs, affectedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1282,7 +1281,6 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var userPeers []*nbpeer.Peer
var targetUser *types.User
var settings *types.Settings
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1303,14 +1301,6 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
if len(userPeers) > 0 {
updateAccountPeers = true
var peerIDs []string
for _, peer := range userPeers {
peerIDs = append(peerIDs, peer.ID)
}
// Resolve before delete so group memberships are still present.
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, peerIDs)
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings)
if err != nil {
return fmt.Errorf("failed to delete user peers: %w", err)
@@ -1334,7 +1324,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err)
}
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs); err != nil {
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}

View File

@@ -846,7 +846,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
@@ -962,7 +962,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
@@ -1531,14 +1531,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
// drain any buffered updates from previous subtests
drainPeerUpdates(updMsg)
// deleting user with no linked peers should not update account peers and not send peer update
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2025,7 +2022,7 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()