Compare commits

...

24 Commits

Author SHA1 Message Date
crn4
eab0826b4e merge main 2026-05-29 14:08:09 +02:00
crn4
7048b87931 fix wasm bin filesize - type aliases 2026-05-29 13:50:06 +02:00
Zoltan Papp
174dc24867 [management] Add SSO session extend flow (management) (#6197)
* add SSO session extend flow (management)

Adds the management-server half of the SSO session-extension feature:

- New ExtendAuthSession gRPC RPC that refreshes a peer's session expiry
  using a fresh JWT, validated through the same pipeline as Login but
  without tearing down the tunnel or redoing the NetworkMap sync.
- Per-peer SessionExpiresAt timestamp on every LoginResponse and
  SyncResponse so connected clients learn the deadline on the existing
  long-lived stream, and admin-side changes (toggling expiration,
  changing the expiration window) reach every peer within seconds.
- SessionExpiresAt(...) helper on Peer that derives the absolute UTC
  deadline from LastLogin + the account-level PeerLoginExpiration
  setting, returning zero when the peer is not SSO-tracked or expiration
  is disabled.

The matching client-side consumer of these fields lands separately.

* encode SessionExpiresAt as 3-state on the wire

Previously the `sessionExpiresAt` field on LoginResponse, SyncResponse
and ExtendAuthSessionResponse was 2-state: a valid timestamp meant
"new deadline", and nil meant "clear". That conflated two distinct
meanings — "no info in this snapshot" vs "expiry is explicitly off /
peer is not SSO-tracked" — so a Sync push that legitimately couldn't
compute the deadline (settings lookup failed) would silently clear the
client's anchor and lose the warning window.

Three states now, encoded on the same field number (no .proto schema
churn — only comments and the server-side encoder change):

  - nil pointer (field absent) → "no info"; client preserves anchor
  - &Timestamp{} (seconds=0, nanos=0) → explicit "disabled / not SSO"
    sentinel; client clears
  - valid timestamp → new absolute UTC deadline

A new encodeSessionExpiresAt helper centralises the zero/non-zero
encoding and is shared by the Sync, Login and ExtendAuthSession
builders. The Sync builder still emits nil when settings are missing.
Login and ExtendAuthSession always carry an authoritative value.

The matching client-side decoder lands on feature/session-extend.

* add UserExtendedPeerSession activity event

ExtendAuthSession previously reused UserLoggedInPeer for its audit
record, which conflated two distinct user actions: a full interactive
SSO login (tunnel re-established, network map resync) versus an
in-place deadline refresh (tunnel untouched). Auditors reading the log
couldn't tell which one happened, and downstream dashboards/alerts on
"login" volume were polluted by routine extends.

Adds a dedicated UserExtendedPeerSession Activity (code 125,
"user.peer.session.extend") and switches ExtendPeerSession over to it.
The peer-extend audit trail is now distinguishable from interactive
logins.

* make ExtendAuthSession JWT-retry backoff cancellable

Skip the retry log and 200ms wait on the final attempt, and replace the
uncancellable time.Sleep with a select on time.After/ctx.Done so an
upstream cancellation aborts the wait instead of running it to
completion.
2026-05-28 19:14:14 +02:00
crn4
596952265d wgkey as peer id 2026-05-28 12:52:15 +02:00
Riccardo Manfrin
7ea5e37dd4 [client] Improve rosenpass support (#6136)
* Updates rosenpass version

go-rosenpass v0.4.0 → v0.5.42 bump — detailed findings

Change summary
cunicu.li/go-rosenpass  v0.4.0  → v0.5.42   (target)
cilium/ebpf             v0.15.0 → v0.19.0   (transitive)
gopacket/gopacket       v1.1.1  → v1.4.0    (transitive)
wireguard               2023-07 → 2023-12   (transitive)
wireguard/wgctrl        2023-04 → 2024-12   (transitive)

Wire interop

v0.4.0 (in v0.70.5) <-> v0.5.42 OK
v0.5.42 <-> v0.5.42 OK

Quantum resistance: true both ends

---
**Replay error eliminated.**

Before (on v0.4.0):

`ERROR Failed to handle message: failed to load biscuit (ICR1): detected replay`

Recurring every ~50ms for minutes at a time. Gone entirely after both ends upgraded to v0.5.42. Upstream fix in biscuit/replay handling between v0.4.x and v0.5.x series.

* Fixup [::]:port socket trying to send to v4

* Adds more tests on netbird<->rosenpass interactions

* Anticipates rp handler creation before generateConfig

* [client] Moves deterministic key gen into rosenpass

* go mod tidy

* Adds reminder to reason about rosenpass surface area

* Apply code rabbit suggestions
2026-05-28 09:01:18 +02:00
Riccardo Manfrin
9d7ef9b255 [client] Fix statemanager possible deadlock (#6228)
1. Stop() takes m.mu.Lock() and defers m.mu.Unlock()
2. <-m.done under lock
3. periodicStateSave defers close(m.done)
4. periodicStateSave calls PersistState() (line 256) which does m.mu.Lock()

Double Stop() remains idempotent: second cancel() on dead ctx
 (no-op) and reads done already closed (immediate return).
2026-05-28 08:54:15 +02:00
Pascal Fischer
944a258459 [management] extend nmap monitoring (#6271) 2026-05-27 16:56:02 +02:00
crn4
21cfec93d4 comments cleanup 2026-05-27 16:51:55 +02:00
Pascal Fischer
1f9a829f2c [management] update log levels (#6266) 2026-05-27 11:43:49 +02:00
crn4
98818e3095 merge main\ 2026-05-27 10:50:24 +02:00
crn4
5d5c2d9f95 filtering fix 2026-05-26 11:33:21 +02:00
crn4
13e41e432c idp dex fix 2026-05-21 15:21:28 +02:00
crn4
efa6a3f502 missed file 2026-05-20 12:41:05 +02:00
crn4
5fbcdeceac more comments 2026-05-19 21:41:08 +02:00
crn4
3a1bbeba90 review comments 2026-05-19 20:27:50 +02:00
crn4
728057ef15 missed files for client side and shared files 2026-05-19 14:46:23 +02:00
crn4
582cd70086 client side and components on shared folder 2026-05-19 14:46:09 +02:00
crn4
9bbbafaf69 int id for networks and posture checks migration 2026-05-19 14:45:40 +02:00
crn4
672b057aa0 fix Group.Copy losing AccountSeqID 2026-05-19 14:43:59 +02:00
crn4
b9a0186200 fix routes filtering in account componnents 2026-05-19 14:43:49 +02:00
crn4
9083bdb977 capabilities conditioning 2026-05-19 14:43:38 +02:00
crn4
b194af48b8 wire size benches fix 2026-05-19 14:43:28 +02:00
crn4
4543780ef0 grpc components encoding with optimisations 2026-05-19 14:43:17 +02:00
crn4
2de0283971 init int inds migration 2026-05-19 14:42:55 +02:00
90 changed files with 11359 additions and 1553 deletions

View File

@@ -35,7 +35,7 @@ jobs:
display_name: Linux display_name: Linux
name: ${{ matrix.display_name }} name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
timeout-minutes: 15 timeout-minutes: 25
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -58,4 +58,4 @@ jobs:
skip-cache: true skip-cache: true
skip-save-cache: true skip-save-cache: true
cache-invalidation-interval: 0 cache-invalidation-interval: 0
args: --timeout=12m args: --timeout=20m

View File

@@ -61,9 +61,11 @@ import (
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
types "github.com/netbirdio/netbird/shared/management/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
@@ -202,6 +204,13 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
// latestComponents is the most-recent NetworkMapComponents decoded from
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
// NetworkMap that Calculate() produced from it so future incremental
// updates have a base to apply changes against. nil for legacy-format
// peers. Guarded by syncMsgMux.
latestComponents *types.NetworkMapComponents
networkMonitor *networkmonitor.NetworkMonitor networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer sshServer sshServer
@@ -865,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err() return e.ctx.Err()
} }
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil { // Envelope sync responses carry PeerConfig at the top level; legacy
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate) // NetworkMap syncs carry it under NetworkMap.PeerConfig.
if pc := update.GetPeerConfig(); pc != nil {
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
} }
if update.GetNetbirdConfig() != nil { if update.GetNetbirdConfig() != nil {
@@ -907,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
nm := update.GetNetworkMap() var (
nm *mgmProto.NetworkMap
components *types.NetworkMapComponents
)
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
// Components-format peer: decode the envelope back to typed
// components, run Calculate() locally, and convert to the wire
// NetworkMap shape the rest of the engine consumes. Components are
// retained so future incremental updates can apply deltas instead
// of doing a full reconstruction.
localKey := e.config.WgPrivateKey.PublicKey().String()
dnsName := ""
if pc := update.GetPeerConfig(); pc != nil {
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
// shared domain by stripping the peer's own label prefix. Falls
// back to empty if the FQDN doesn't have the expected shape.
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
}
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
if err != nil {
return fmt.Errorf("decode network map envelope: %w", err)
}
nm = result.NetworkMap
components = result.Components
} else {
nm = update.GetNetworkMap()
}
if nm == nil { if nm == nil {
return nil return nil
} }
// Only retain the components view when the server sent the envelope
// path. A legacy proto.NetworkMap means components == nil; writing it
// here would clobber a previously-cached snapshot, breaking the
// incremental-delta base on a future envelope sync.
if components != nil {
e.latestComponents = components
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux. // Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too. // Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock() e.syncRespMux.RLock()
@@ -937,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
// receiving peer's FQDN — the same value the management server fills as
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
// "netbird.cloud". An empty string is returned for unrecognized formats.
func extractDNSDomainFromFQDN(fqdn string) string {
for i := 0; i < len(fqdn); i++ {
if fqdn[i] == '.' && i+1 < len(fqdn) {
return fqdn[i+1:]
}
}
return ""
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil { if update != nil {
// when we receive token we expect valid address list too // when we receive token we expect valid address list too

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
} }
// Fallback to deterministic key if no NetBird PSK is configured // Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey() determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
if err != nil { if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err) conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil return nil
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey return determKey
} }
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil)) return hex.EncodeToString(hasher.Sum(nil))
} }
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct { type Manager struct {
ifaceName string ifaceName string
spk []byte spk []byte
@@ -36,7 +45,7 @@ type Manager struct {
preSharedKey *[32]byte preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler rpWgHandler *NetbirdHandler
server *rp.Server server rpServer
lock sync.Mutex lock sync.Mutex
port int port int
wgIface PresharedKeySetter wgIface PresharedKeySetter
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public) rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash) log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
} }
func (m *Manager) GetPubKey() []byte { func (m *Manager) GetPubKey() []byte {
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server // addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error { func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey} pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil { if m.preSharedKey != nil {
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil { if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err) return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
} }
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
} }
peerID, err := m.server.AddPeer(pcfg) peerID, err := m.server.AddPeer(pcfg)
if err != nil { if err != nil {
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
return err return err
} }
m.server, err = rp.NewUDPServer(conf) server, err := rp.NewUDPServer(conf)
if err != nil { if err != nil {
return err return err
} }
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port) log.Infof("starting rosenpass server on port %d", m.port)
return m.server.Run() return server.Run()
} }
// Close closes the Rosenpass server // Close closes the Rosenpass server
func (m *Manager) Close() error { func (m *Manager) Close() error {
if m.server != nil { m.lock.Lock()
err := m.server.Close() server := m.server
if err != nil { m.server = nil
log.Errorf("failed closing local rosenpass server") m.lock.Unlock()
} if server == nil {
m.server = nil return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
} }
return nil return nil
} }

View File

@@ -1,14 +1,412 @@
package rosenpass package rosenpass
import ( import (
"errors"
"os"
"sync"
"testing" "testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) { func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort() port, err := findRandomAvailableUDPPort()
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, port, 0) require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535) require.LessOrEqual(t, port, 65535)
} }
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -0,0 +1,42 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -0,0 +1,44 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
} }
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() cancel := m.cancel
done := m.done
m.mu.Unlock()
if m.cancel == nil { if cancel == nil {
return nil return nil
} }
m.cancel() cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-done:
} }
return nil return nil

View File

@@ -53,6 +53,9 @@ type NameServerGroup struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
// Name group name // Name group name
Name string Name string
// Description group description // Description group description

10
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5 go 1.25.5
require ( require (
cunicu.li/go-rosenpass v0.4.0 cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0 github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4 github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0 golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0 golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0 google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3 github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0 github.com/cilium/ebpf v0.19.0
github.com/coder/websocket v1.8.14 github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0 github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0 github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2

22
go.sum
View File

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA= codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw= codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw= cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M= cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ= github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=

View File

@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
if file == "" { if file == "" {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config") return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
} }
return (&sql.SQLite3{File: file}).Open(logger) return newSQLite3(file).Open(logger)
case "postgres": case "postgres":
dsn, _ := s.Config["dsn"].(string) dsn, _ := s.Config["dsn"].(string)
if dsn == "" { if dsn == "" {

View File

@@ -20,7 +20,6 @@ import (
"github.com/dexidp/dex/server" "github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer" "github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4" jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Initialize SQLite storage // Initialize SQLite storage
dbPath := filepath.Join(config.DataDir, "oidc.db") dbPath := filepath.Join(config.DataDir, "oidc.db")
sqliteConfig := &sql.SQLite3{File: dbPath} sqliteConfig := newSQLite3(dbPath)
stor, err := sqliteConfig.Open(logger) stor, err := sqliteConfig.Open(logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open storage: %w", err) return nil, fmt.Errorf("failed to open storage: %w", err)

15
idp/dex/sqlite_cgo.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
// struct that takes a File path. Non-CGO builds get an empty stub whose
// Open() returns the dex "SQLite not available" error — correct behaviour
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
func newSQLite3(file string) *sql.SQLite3 {
return &sql.SQLite3{File: file}
}

15
idp/dex/sqlite_nocgo.go Normal file
View File

@@ -0,0 +1,15 @@
//go:build !cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
// Open() returns an error documenting the missing CGO support — correct
// behaviour for cross-compiled artefacts that never actually run the
// embedded IdP. The `file` argument is ignored.
func newSQLite3(_ string) *sql.SQLite3 {
return &sql.SQLite3{}
}

View File

@@ -55,6 +55,12 @@ type Controller struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
// componentsDisabled, when true, forces the controller to emit legacy
// proto.NetworkMap to every peer regardless of capability. Set once at
// construction and never written after — readers race-free without a
// mutex.
componentsDisabled bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -81,12 +87,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
settingsManager: settingsManager, settingsManager: settingsManager,
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
config: config, config: config,
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
proxyController: proxyController, proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager, EphemeralPeersManager: ephemeralPeersManager,
} }
} }
// PeerNeedsComponents reports whether the gRPC layer should emit the
// component-based wire format for this peer.
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
}
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
// literal.
func parseBoolEnv(key string) bool {
v, _ := strconv.ParseBool(os.Getenv(key))
return v
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil { if err != nil {
@@ -112,7 +133,7 @@ func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams() return c.peersUpdateManager.CountStreams()
} }
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
@@ -175,6 +196,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
continue continue
} }
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountNmapTriggered(string(reason.Resource), string(reason.Operation))
}
wg.Add(1) wg.Add(1)
semaphore <- struct{}{} semaphore <- struct{}{}
go func(p *nbpeer.Peer) { go func(p *nbpeer.Peer) {
@@ -192,18 +217,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start)) c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now() start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap := proxyNetworkMaps[p.ID]
if ok { if result.NetworkMap != nil && proxyNetworkMap != nil {
remotePeerNetworkMap.Merge(proxyNetworkMap) result.NetworkMap.Merge(proxyNetworkMap)
} }
peerGroups := account.GetPeerGroups(p.ID) peerGroups := account.GetPeerGroups(p.ID)
start = time.Now() start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) var update *proto.SyncResponse
if result.IsComponents() {
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
// the client merges it into Calculate()'s output the same
// way the legacy server did via NetworkMap.Merge.
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
}
c.metrics.CountToSyncResponseDuration(time.Since(start)) c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
@@ -242,14 +275,14 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
go func() { go func() {
defer b.mu.Unlock() defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() { if !b.update.Load() {
return return
} }
b.update.Store(false) b.update.Store(false)
if b.next == nil { if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
}) })
return return
} }
@@ -265,7 +298,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
if c.accountManagerMetrics != nil { if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
} }
return c.sendUpdateAccountPeers(ctx, accountID) return c.sendUpdateAccountPeers(ctx, accountID, reason)
} }
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
@@ -314,11 +347,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err return err
} }
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap := proxyNetworkMaps[peer.ID]
if ok { if result.NetworkMap != nil && proxyNetworkMap != nil {
remotePeerNetworkMap.Merge(proxyNetworkMap) result.NetworkMap.Merge(proxyNetworkMap)
} }
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
@@ -329,7 +362,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
peerGroups := account.GetPeerGroups(peerId) peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) var update *proto.SyncResponse
if result.IsComponents() {
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
}
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update, Update: update,
MessageType: network_map.MessageTypeNetworkMap, MessageType: network_map.MessageTypeNetworkMap,
@@ -359,14 +397,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
go func() { go func() {
defer b.mu.Unlock() defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() { if !b.update.Load() {
return return
} }
b.update.Store(false) b.update.Store(false)
if b.next == nil { if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.sendUpdateAccountPeers(ctx, accountID, reason)
}) })
return return
} }
@@ -376,6 +414,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil return nil
} }
// GetValidatedPeerWithComponents is the components-format counterpart of
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
// encodes both into the wire envelope. Callers must gate on capability
// themselves before dispatching here — this method does NOT branch on it.
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, nil, 0, err
}
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
return nil, nil, nil, nil, 0, err
}
// Fetch the proxy network map fragment for this peer alongside the
// components — same single-account-load path the streaming controller
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
// instead of waiting for the next streaming push.
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval { if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID) network, err := c.repo.GetAccountNetwork(ctx, accountID)

View File

@@ -22,6 +22,10 @@ type Controller interface {
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
// PeerNeedsComponents combines the peer's advertised capability with the
// kill-switch flag — the only public predicate gRPC layers should ask.
PeerNeedsComponents(p *nbpeer.Peer) bool
GetDNSDomain(settings *types.Settings) string GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context) StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
} }
// GetValidatedPeerWithComponents mocks base method.
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMapComponents)
ret2, _ := ret[2].(*types.NetworkMap)
ret3, _ := ret[3].([]*posture.Checks)
ret4, _ := ret[4].(int64)
ret5, _ := ret[5].(error)
return ret0, ret1, ret2, ret3, ret4, ret5
}
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
}
// PeerNeedsComponents mocks base method.
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
ret0, _ := ret[0].(bool)
return ret0
}
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
}
// OnPeerConnected mocks base method. // OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -51,7 +51,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
found = true found = true
select { select {
case channel <- update: case channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) log.WithContext(ctx).Tracef("update was sent to channel for peer %s", peerID)
default: default:
dropped = true dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))

View File

@@ -0,0 +1,813 @@
package grpc
import (
"encoding/base64"
"strconv"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// wgKeyRawLen is the raw byte length of a WireGuard public key.
const wgKeyRawLen = 32
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
// The envelope is fully self-contained — every field needed by the client's
// local Calculate() comes from the components struct itself. The only
// externally-supplied data is the receiving peer's PeerConfig (which is
// computed alongside the components in the network_map controller and reused
// from the legacy proto path) and the dns_domain string.
type ComponentsEnvelopeInput struct {
Components *types.NetworkMapComponents
PeerConfig *proto.PeerConfig
DNSDomain string
DNSForwarderPort int64
// UserIDClaim is the OIDC claim name the client should embed in
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
// is OK — client treats empty as "no SshAuth to build".
UserIDClaim string
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
// external controllers (BYOP/port-forwarding). Nil when no proxy data
// is present; encoder skips the field in that case.
ProxyPatch *proto.ProxyPatch
}
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
// wire envelope. The encoder is intentionally non-deterministic: it iterates
// Go maps in their native (random) order. Indexes inside the envelope
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
// are self-consistent within a single encode, so the decoder reconstructs
// the same typed objects regardless of emit order. Tests that need to
// compare envelopes do so semantically via proto round-trip + canonicalize,
// not byte-equal.
//
// Callers must NOT concatenate or merge envelopes from different encodes —
// index spaces are local to a single envelope.
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
c := in.Components
// Graceful degrade when components is nil — matches the legacy path's
// behaviour for missing/unvalidated peers (return a NetworkMap with only
// Network populated). The receiver gets an envelope it can decode
// without crashing; AccountSettings stays non-nil so client-side
// dereferences are safe.
if c == nil {
// Match legacy missing-peer minimum: a NetworkMap with only Network
// populated. The receiver gets enough to bootstrap (Network
// identifier, dns_domain, account_settings) and nothing else.
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{
Full: &proto.NetworkMapComponentsFull{
PeerConfig: in.PeerConfig,
DnsDomain: in.DNSDomain,
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
AccountSettings: &proto.AccountSettingsCompact{},
ProxyPatch: in.ProxyPatch,
},
},
}
}
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
// every regular peer (in c.Peers) must be indexed before any encoder
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
// peers that exist only in c.RouterPeers would silently lose their
// peer_index reference.
enc := newComponentEncoder(c)
enc.indexAllPeers()
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
// Phase 2: gather every policy that any consumer references (peer-pair
// policies + resource-only policies) so encodeResourcePoliciesMap can
// translate every *Policy pointer to a wire index.
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
policies, policyToIdxs := enc.encodePolicies(allPolicies)
// Phase 3: emit. Order of struct field expressions no longer matters:
// every encoder either reads from the dedup tables or works on
// independent input.
full := &proto.NetworkMapComponentsFull{
Serial: networkSerial(c.Network),
PeerConfig: in.PeerConfig,
Network: toAccountNetwork(c.Network),
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
ProxyPatch: in.ProxyPatch,
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
DnsDomain: in.DNSDomain,
CustomZoneDomain: c.CustomZoneDomain,
AgentVersions: enc.agentVersions,
Peers: enc.peers,
RouterPeerIndexes: routerIdxs,
Policies: policies,
Groups: enc.encodeGroups(),
Routes: enc.encodeRoutes(c.Routes),
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
AccountZones: encodeCustomZones(c.AccountZones),
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
}
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
}
}
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
// production path always populates c.Network, but the encoder is exported
// and a hand-built components struct may omit it.
func networkSerial(n *types.Network) uint64 {
if n == nil {
return 0
}
return n.CurrentSerial()
}
type componentEncoder struct {
components *types.NetworkMapComponents
peerOrder map[string]uint32
peers []*proto.PeerCompact
agentVersionOrder map[string]uint32
agentVersions []string
}
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
return &componentEncoder{
components: c,
peerOrder: make(map[string]uint32, len(c.Peers)),
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
agentVersionOrder: make(map[string]uint32),
}
}
func (e *componentEncoder) indexAllPeers() {
for _, p := range e.components.Peers {
if p == nil {
continue
}
e.appendPeer(p)
}
}
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
if idx, ok := e.peerOrder[p.ID]; ok {
return idx
}
idx := uint32(len(e.peers))
e.peerOrder[p.ID] = idx
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
return idx
}
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
if idx, ok := e.agentVersionOrder[v]; ok {
return idx
}
// Lazy-initialise the table with "" at index 0 so the empty string
// stays interchangeable with proto3's default uint32=0 — peers without
// a WtVersion don't force the table to materialise.
if v == "" {
idx := uint32(len(e.agentVersions))
if idx == 0 {
e.agentVersions = append(e.agentVersions, "")
}
e.agentVersionOrder[""] = idx
return idx
}
if len(e.agentVersions) == 0 {
e.agentVersions = append(e.agentVersions, "")
e.agentVersionOrder[""] = 0
}
idx := uint32(len(e.agentVersions))
e.agentVersionOrder[v] = idx
e.agentVersions = append(e.agentVersions, v)
return idx
}
// indexRouterPeers ensures every router peer is in the peer dedup table
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
// run before any encoder that resolves peer ids via e.peerOrder.
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
if len(routers) == 0 {
return nil
}
out := make([]uint32, 0, len(routers))
for _, p := range routers {
if p == nil {
continue
}
out = append(out, e.appendPeer(p))
}
return out
}
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
if len(e.components.Groups) == 0 {
return nil
}
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
for _, g := range e.components.Groups {
if !g.HasSeqID() {
continue
}
peerIdxs := make([]uint32, 0, len(g.Peers))
for _, peerID := range g.Peers {
if idx, ok := e.peerOrder[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
out = append(out, &proto.GroupCompact{
Id: g.AccountSeqID,
Name: g.Name,
PeerIndexes: peerIdxs,
})
}
return out
}
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
// list and a map from policy pointer to the indexes of its emitted rules in
// that list — used by encodeResourcePoliciesMap to translate
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
if len(policies) == 0 {
return nil, nil
}
out := make([]*proto.PolicyCompact, 0, len(policies))
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
for _, pol := range policies {
if !pol.HasSeqID() || !pol.Enabled {
continue
}
for _, r := range pol.Rules {
if r == nil || !r.Enabled {
continue
}
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
out = append(out, e.encodePolicyRule(pol, r))
}
}
return out, idxByPolicy
}
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
return &proto.PolicyCompact{
Id: pol.AccountSeqID,
Action: networkmap.GetProtoAction(string(r.Action)),
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
Bidirectional: r.Bidirectional,
Ports: portsToUint32(r.Ports),
PortRanges: portRangesToProto(r.PortRanges),
SourceGroupIds: e.groupSeqIDs(r.Sources),
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
AuthorizedUser: r.AuthorizedUser,
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
SourceResource: e.resourceToProto(r.SourceResource),
DestinationResource: e.resourceToProto(r.DestinationResource),
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
}
}
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
// dropping any group that has no seq id assigned.
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
if len(src) == 0 {
return nil
}
out := make([]uint32, 0, len(src))
for _, gid := range src {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
// unionPolicies merges c.Policies with every policy referenced by
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
// policies (relevant to a NetworkResource but not to peer-pair traffic)
// only live in ResourcePoliciesMap; without this union step they'd be lost
// from the wire and the client's resource-policy lookup would come back
// empty.
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
// Fast path: non-router peers have no resource-only policies, so the
// "union" is identical to `policies`. Skip the dedup map allocation.
if len(resourcePolicies) == 0 {
return policies
}
seen := make(map[*types.Policy]struct{}, len(policies))
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
for _, list := range resourcePolicies {
for _, p := range list {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
}
return out
}
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
// group xid → local-user names) to the wire form (map keyed by group
// account_seq_id → UserNameList). Groups without a seq id are dropped —
// matches how source/destination group references handle the same case.
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserNameList, len(m))
for groupID, names := range m {
seq, ok := e.groupSeq(groupID)
if !ok {
continue
}
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
}
return out
}
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
g, ok := e.components.Groups[groupID]
if !ok || !g.HasSeqID() {
return 0, false
}
return g.AccountSeqID, true
}
// resourceToProto translates types.Resource for the wire. For peer-typed
// resources the peer id is converted to a peer index into the envelope's
// peers array. For other resource types only the type string is shipped
// today (Calculate's resource-typed rule path consults SourceResource only
// for "peer" — other types fall through to group-based lookup).
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
if r.ID == "" && r.Type == "" {
return nil
}
out := &proto.ResourceCompact{Type: string(r.Type)}
if r.Type == types.ResourceTypePeer && r.ID != "" {
if idx, ok := e.peerOrder[r.ID]; ok {
out.PeerIndexSet = true
out.PeerIndex = idx
}
}
return out
}
// postureCheckSeqs translates a slice of posture-check xids to their
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
// lookup. Unresolvable xids are silently dropped — matches how group/peer
// references handle the same case.
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
return nil
}
out := make([]uint32, 0, len(xids))
for _, xid := range xids {
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
out = append(out, seq)
}
}
return out
}
// networkSeq translates a Network xid to its per-account integer id using
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
// the xid isn't known — callers decide whether to skip the parent record.
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
if xid == "" {
return 0, false
}
seq, ok := e.components.NetworkXIDToSeq[xid]
if !ok || seq == 0 {
return 0, false
}
return seq, true
}
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
if s == nil || len(s.DisabledManagementGroups) == 0 {
return nil
}
out := &proto.DNSSettingsCompact{
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
}
for _, gid := range s.DisabledManagementGroups {
if seq, ok := e.groupSeq(gid); ok {
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
}
}
return out
}
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
if len(routes) == 0 {
return nil
}
out := make([]*proto.RouteRaw, 0, len(routes))
for _, r := range routes {
if r == nil {
continue
}
rr := &proto.RouteRaw{
Id: r.AccountSeqID,
NetId: string(r.NetID),
Description: r.Description,
KeepRoute: r.KeepRoute,
NetworkType: int32(r.NetworkType),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
SkipAutoApply: r.SkipAutoApply,
Domains: r.Domains.ToPunycodeList(),
GroupIds: e.groupIDsToSeq(r.Groups),
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
}
if r.Network.IsValid() {
rr.NetworkCidr = r.Network.String()
}
if r.Peer != "" {
if idx, ok := e.peerOrder[r.Peer]; ok {
rr.PeerIndexSet = true
rr.PeerIndex = idx
}
}
out = append(out, rr)
}
return out
}
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
if len(groupIDs) == 0 {
return nil
}
out := make([]uint32, 0, len(groupIDs))
for _, gid := range groupIDs {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
if len(nsgs) == 0 {
return nil
}
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
for _, nsg := range nsgs {
if nsg == nil {
continue
}
entry := &proto.NameServerGroupRaw{
Id: nsg.AccountSeqID,
Name: nsg.Name,
Description: nsg.Description,
Nameservers: encodeNameServers(nsg.NameServers),
GroupIds: e.groupIDsToSeq(nsg.Groups),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
}
out = append(out, entry)
}
return out
}
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
if len(servers) == 0 {
return nil
}
out := make([]*proto.NameServer, 0, len(servers))
for _, s := range servers {
out = append(out, &proto.NameServer{
IP: s.IP.String(),
NSType: int64(s.NSType),
Port: int64(s.Port),
})
}
return out
}
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
if len(records) == 0 {
return nil
}
out := make([]*proto.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, &proto.SimpleRecord{
Name: r.Name,
Type: int64(r.Type),
Class: r.Class,
TTL: int64(r.TTL),
RData: r.RData,
})
}
return out
}
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
if len(zones) == 0 {
return nil
}
out := make([]*proto.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, &proto.CustomZone{
Domain: z.Domain,
Records: encodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
if len(resources) == 0 {
return nil
}
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
for _, r := range resources {
if r == nil {
continue
}
entry := &proto.NetworkResourceRaw{
Id: r.AccountSeqID,
Name: r.Name,
Description: r.Description,
Type: string(r.Type),
Address: r.Address,
DomainValue: r.Domain,
Enabled: r.Enabled,
}
if seq, ok := e.networkSeq(r.NetworkID); ok {
entry.NetworkSeq = seq
}
if r.Prefix.IsValid() {
entry.PrefixCidr = r.Prefix.String()
}
out = append(out, entry)
}
return out
}
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
if len(routersMap) == 0 {
return nil
}
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
for networkXID, routers := range routersMap {
if len(routers) == 0 {
continue
}
netSeq, ok := e.networkSeq(networkXID)
if !ok {
continue
}
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
for peerID, r := range routers {
if r == nil {
continue
}
entry := &proto.NetworkRouterEntry{
Id: r.AccountSeqID,
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
}
if idx, ok := e.peerOrder[peerID]; ok {
entry.PeerIndexSet = true
entry.PeerIndex = idx
}
entries = append(entries, entry)
}
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
}
return out
}
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
if len(rpm) == 0 {
return nil
}
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
// (small slice). Network resources without seq id are dropped, matching how
// other components-without-seq are silently filtered.
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
for _, r := range e.components.NetworkResources {
if r != nil && r.AccountSeqID != 0 {
resourceXIDToSeq[r.ID] = r.AccountSeqID
}
}
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
for resourceXID, policies := range rpm {
seq, ok := resourceXIDToSeq[resourceXID]
if !ok {
continue
}
idxs := make([]uint32, 0, len(policies)*2)
for _, pol := range policies {
idxs = append(idxs, policyToIdxs[pol]...)
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
}
return out
}
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserIDList, len(m))
for groupID, userIDs := range m {
seq, ok := e.groupSeq(groupID)
if !ok || len(userIDs) == 0 {
continue
}
out[seq] = &proto.UserIDList{UserIds: userIDs}
}
return out
}
func stringSetToSlice(s map[string]struct{}) []string {
if len(s) == 0 {
return nil
}
out := make([]string, 0, len(s))
for k := range s {
out = append(out, k)
}
return out
}
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.PeerIndexSet, len(m))
for checkXID, failedPeerIDs := range m {
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
if !ok || seq == 0 {
continue
}
idxs := make([]uint32, 0, len(failedPeerIDs))
for peerID := range failedPeerIDs {
if idx, ok := e.peerOrder[peerID]; ok {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
}
return out
}
// toAccountSettingsCompact always returns a non-nil message — the client
// dereferences it unconditionally during Calculate(), so a nil here would
// crash the receiver. A missing types.AccountSettingsInfo on the server
// (which shouldn't happen in production but the encoder is exported)
// degrades to login_expiration_enabled = false, which makes
// LoginExpired() return false for every peer.
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
if s == nil {
return &proto.AccountSettingsCompact{}
}
return &proto.AccountSettingsCompact{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
}
}
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
if n == nil {
return nil
}
out := &proto.AccountNetwork{
Identifier: n.Identifier,
NetCidr: n.Net.String(),
Dns: n.Dns,
Serial: n.CurrentSerial(),
}
if len(n.NetV6.IP) > 0 {
out.NetV6Cidr = n.NetV6.String()
}
return out
}
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
pc := &proto.PeerCompact{
WgPubKey: decodeWgKey(p.Key),
SshPubKey: []byte(p.SSHKey),
DnsLabel: p.DNSLabel,
AgentVersionIdx: agentVersionIdx,
AddedWithSsoLogin: p.UserID != "",
LoginExpirationEnabled: p.LoginExpirationEnabled,
SshEnabled: p.SSHEnabled,
SupportsIpv6: p.SupportsIPv6(),
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
}
if p.LastLogin != nil {
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
}
switch {
case !p.IP.IsValid():
// leave Ip nil
case p.IP.Is4() || p.IP.Is4In6():
ip := p.IP.Unmap().As4()
pc.Ip = ip[:]
default:
ip := p.IP.As16()
pc.Ip = ip[:]
}
if p.IPv6.IsValid() {
ip := p.IPv6.As16()
pc.Ipv6 = ip[:]
}
return pc
}
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
// key, or nil for an empty / malformed key.
func decodeWgKey(s string) []byte {
if s == "" {
return nil
}
out := make([]byte, wgKeyRawLen)
n, err := base64.StdEncoding.Decode(out, []byte(s))
if err != nil || n != wgKeyRawLen {
return nil
}
return out
}
func portsToUint32(ports []string) []uint32 {
if len(ports) == 0 {
return nil
}
out := make([]uint32, 0, len(ports))
for _, p := range ports {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
continue
}
out = append(out, uint32(v))
}
return out
}
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
if len(ranges) == 0 {
return nil
}
out := make([]*proto.PortInfo_Range, 0, len(ranges))
for _, r := range ranges {
out = append(out, &proto.PortInfo_Range{
Start: uint32(r.Start),
End: uint32(r.End),
})
}
return out
}

View File

@@ -0,0 +1,879 @@
package grpc
import (
"bytes"
"cmp"
"net"
"net/netip"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
// to reference the new peer indexes. Groups, policies, and router indexes are
// also sorted. After canonicalize, two envelopes built from the same logical
// input compare byte-equal via proto.Equal.
//
// This lives on the test side — the encoder itself emits in map-iteration
// order. Test-side normalization is the contract for "two encodes are
// equivalent".
func canonicalize(full *proto.NetworkMapComponentsFull) {
if full == nil {
return
}
// Canonicalize agent_versions first: sort the slice and rewrite each
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
// index 0 by convention.
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
if len(full.AgentVersions) > 0 {
// Pair version → original index, sort, rebuild.
type avEntry struct {
version string
oldIdx uint32
}
entries := make([]avEntry, len(full.AgentVersions))
for i, v := range full.AgentVersions {
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
}
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
// keeps the canonicalize output stable when two entries compare
// equal (the encoder dedups, but defending against future inputs).
slices.SortFunc(entries, func(a, b avEntry) int {
if a.version == "" && b.version != "" {
return -1
}
if b.version == "" && a.version != "" {
return 1
}
if c := cmp.Compare(a.version, b.version); c != 0 {
return c
}
return cmp.Compare(a.oldIdx, b.oldIdx)
})
newVersions := make([]string, len(entries))
for newIdx, e := range entries {
avRemap[e.oldIdx] = uint32(newIdx)
newVersions[newIdx] = e.version
}
full.AgentVersions = newVersions
}
for _, p := range full.Peers {
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
p.AgentVersionIdx = newIdx
}
}
type peerEntry struct {
peer *proto.PeerCompact
oldIdx uint32
}
entries := make([]peerEntry, len(full.Peers))
for i, p := range full.Peers {
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
}
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
// nil from malformed keys, or both empty for placeholders).
slices.SortFunc(entries, func(a, b peerEntry) int {
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
return c
}
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
})
remap := make(map[uint32]uint32, len(entries))
newPeers := make([]*proto.PeerCompact, len(entries))
for newIdx, e := range entries {
remap[e.oldIdx] = uint32(newIdx)
newPeers[newIdx] = e.peer
}
full.Peers = newPeers
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
for _, g := range full.Groups {
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
}
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
for _, r := range full.Routes {
if r.PeerIndexSet {
if newIdx, ok := remap[r.PeerIndex]; ok {
r.PeerIndex = newIdx
}
}
slices.Sort(r.GroupIds)
slices.Sort(r.AccessControlGroupIds)
slices.Sort(r.PeerGroupIds)
}
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
for _, list := range full.RoutersMap {
for _, entry := range list.Entries {
if entry.PeerIndexSet {
if newIdx, ok := remap[entry.PeerIndex]; ok {
entry.PeerIndex = newIdx
}
}
slices.Sort(entry.PeerGroupIds)
}
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
}
for _, set := range full.PostureFailedPeers {
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
}
for _, p := range full.Policies {
slices.Sort(p.SourceGroupIds)
slices.Sort(p.DestinationGroupIds)
}
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
// multiple PolicyCompact entries sharing the same Id (one per rule, when
// a Policy has multiple rules) still get a deterministic order. After
// sorting we remap indexes in ResourcePoliciesMap.
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
for i, p := range full.Policies {
policyOldOrder[p] = uint32(i)
}
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
if c := cmp.Compare(a.Id, b.Id); c != 0 {
return c
}
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
return c
}
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
})
policyRemap := make(map[uint32]uint32, len(full.Policies))
for newIdx, p := range full.Policies {
policyRemap[policyOldOrder[p]] = uint32(newIdx)
}
for _, idxs := range full.ResourcePoliciesMap {
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
}
for _, list := range full.GroupIdToUserIds {
slices.Sort(list.UserIds)
}
slices.Sort(full.AllowedUserIds)
}
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
out := make([]uint32, 0, len(idxs))
for _, i := range idxs {
if newIdx, ok := remap[i]; ok {
out = append(out, newIdx)
}
}
slices.Sort(out)
return out
}
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
// the encoder is intentionally non-deterministic.
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
canonicalize(a.GetFull())
canonicalize(b.GetFull())
return goproto.Equal(a, b)
}
func newTestComponents() *types.NetworkMapComponents {
peerA := &nbpeer.Peer{
ID: "peer-a",
Key: testWgKeyA,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peera",
SSHKey: "ssh-a",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-b",
Key: testWgKeyB,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
DNSLabel: "peerb",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
}
peerC := &nbpeer.Peer{
ID: "peer-c",
Key: testWgKeyC,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
return &types.NetworkMapComponents{
PeerID: "peer-a",
Network: &types.Network{
Identifier: "net-test",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 7,
},
AccountSettings: &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 2 * time.Hour,
},
Peers: map[string]*nbpeer.Peer{
"peer-a": peerA,
"peer-b": peerB,
"peer-c": peerC,
},
Groups: map[string]*types.Group{
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
},
Policies: []*types.Policy{
{
ID: "pol-1",
AccountSeqID: 10,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
Ports: []string{"22", "80"},
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
},
},
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
}
}
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
c := newTestComponents()
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full, "envelope must contain Full payload")
assert.EqualValues(t, 7, full.Serial)
assert.Equal(t, "netbird.cloud", full.DnsDomain)
require.NotNil(t, full.Network)
assert.Equal(t, "net-test", full.Network.Identifier)
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
require.NotNil(t, full.AccountSettings)
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
require.Len(t, full.Peers, 3)
byLabel := map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
byLabel[p.DnsLabel] = p
}
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
}
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
// Hammer it 100 times — Go map iteration is randomized per call, so each
// run produces different wire bytes, but the canonicalized form must
// match.
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d must be semantically equivalent to first encode", i)
}
}
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]*proto.NetworkMapEnvelope, goroutines)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
}()
}
wg.Wait()
for i, got := range results {
require.NotNil(t, got, "goroutine %d returned nil", i)
require.True(t, envelopesEquivalent(expected, got),
"goroutine %d produced inequivalent envelope", i)
}
}
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 2)
groupByID := map[uint32]*proto.GroupCompact{}
for _, g := range full.Groups {
groupByID[g.Id] = g
}
require.Contains(t, groupByID, uint32(1))
require.Contains(t, groupByID, uint32(2))
assert.Equal(t, "Src", groupByID[1].Name)
assert.Equal(t, "Dst", groupByID[2].Name)
assert.Len(t, groupByID[1].PeerIndexes, 1)
assert.Len(t, groupByID[2].PeerIndexes, 2)
}
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.EqualValues(t, 10, pc.Id)
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
assert.True(t, pc.Bidirectional)
assert.Equal(t, []uint32{22, 80}, pc.Ports)
require.Len(t, pc.PortRanges, 1)
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.RouterPeerIndexes, 1)
idx := full.RouterPeerIndexes[0]
require.Less(t, int(idx), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
}
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
"two distinct versions, order depends on map iteration")
idxByLabel := map[string]uint32{}
for _, p := range full.Peers {
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
}
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
}
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
c := newTestComponents()
c.Policies[0].Enabled = false
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
c := newTestComponents()
c.Groups["group-src"].AccountSeqID = 0
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
assert.EqualValues(t, 2, full.Groups[0].Id)
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
// Both peers have nil WgPubKey after decode; canonicalize must still
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
// canonicalize identically.
c := newTestComponents()
c.Peers["peer-a"].Key = "garbage-a-!!!"
c.Peers["peer-b"].Key = "garbage-b-!!!"
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d with two same-key peers must canonicalize equivalently", i)
}
}
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
c := newTestComponents()
c.Peers["peer-a"].Key = "not-base64-!!!"
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Peers, 3)
var byLabel = map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
byLabel[p.DnsLabel] = p
}
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
}
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
c := newTestComponents()
v6Only := &nbpeer.Peer{
ID: "peer-v6",
Key: testWgKeyA,
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
DNSLabel: "peerv6",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.Peers["peer-v6"] = v6Only
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerv6" {
found = p
}
}
require.NotNil(t, found, "ipv6-only peer must be present")
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
assert.Len(t, found.Ipv6, 16)
}
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
c := newTestComponents()
c.Peers["peer-noip"] = &nbpeer.Peer{
ID: "peer-noip",
Key: testWgKeyA,
DNSLabel: "peernoip",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peernoip" {
found = p
}
}
require.NotNil(t, found)
assert.Empty(t, found.Ip)
assert.Empty(t, found.Ipv6)
}
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
}
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
full := env.GetFull()
require.NotNil(t, full)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Groups)
assert.Empty(t, full.Policies)
assert.Empty(t, full.RouterPeerIndexes)
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
}
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
c := newTestComponents()
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
c.Peers["peer-a"].UserID = "user-1"
c.Peers["peer-a"].LoginExpirationEnabled = true
c.Peers["peer-a"].LastLogin = &now
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var pa *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peera" {
pa = p
}
}
require.NotNil(t, pa)
assert.True(t, pa.AddedWithSsoLogin)
assert.True(t, pa.LoginExpirationEnabled)
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
// peer-b has no UserID and no LastLogin → all fields zero-value.
var pb *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerb" {
pb = p
}
}
require.NotNil(t, pb)
assert.False(t, pb.AddedWithSsoLogin)
assert.False(t, pb.LoginExpirationEnabled)
assert.Zero(t, pb.LastLoginUnixNano)
}
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{
{
ID: "route-peer",
AccountSeqID: 100,
NetID: "net-A",
Description: "via peer-c",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Peer: "peer-c", // peer ID, not WG key
Groups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
Enabled: true,
},
{
ID: "route-peergroup",
AccountSeqID: 101,
NetID: "net-B",
Network: netip.MustParsePrefix("10.1.0.0/16"),
PeerGroups: []string{"group-src", "group-dst"},
Enabled: true,
},
{
ID: "route-no-seq",
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
Network: netip.MustParsePrefix("10.2.0.0/16"),
Enabled: true,
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 3)
byNetID := map[string]*proto.RouteRaw{}
for _, r := range full.Routes {
byNetID[r.NetId] = r
}
r1 := byNetID["net-A"]
require.NotNil(t, r1)
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
require.Less(t, int(r1.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
assert.Empty(t, r1.PeerGroupIds)
r2 := byNetID["net-B"]
require.NotNil(t, r2)
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{{
ID: "route-x",
AccountSeqID: 100,
Peer: "peer-not-in-components",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Enabled: true,
}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 1)
assert.False(t, full.Routes[0].PeerIndexSet,
"missing peer reference must not pretend to point at peer index 0")
}
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
c := newTestComponents()
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
// is the I1 case — without unionPolicies the encoder would silently
// drop it from the wire.
resourceOnlyPolicy := &types.Policy{
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
}
c.ResourcePoliciesMap = map[string][]*types.Policy{
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
}
// Resource must appear in components.NetworkResources with a seq id —
// encoder uses that to translate the xid map key to uint32.
c.NetworkResources = []*resourceTypes.NetworkResource{
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
policyByID := map[uint32]*proto.PolicyCompact{}
policyIdxByID := map[uint32]uint32{}
for i, p := range full.Policies {
policyByID[p.Id] = p
policyIdxByID[p.Id] = uint32(i)
}
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
idxs := full.ResourcePoliciesMap[77].Indexes
require.Len(t, idxs, 2)
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
"resource policies map must reference both wire policy indexes")
}
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
c := newTestComponents()
c.NameServerGroups = []*nbdns.NameServerGroup{{
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
}},
Groups: []string{"group-src", "group-not-persisted"},
Primary: true, Enabled: true,
Domains: []string{"corp.example"},
}}
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.NameserverGroups, 1)
nsg := full.NameserverGroups[0]
assert.EqualValues(t, 50, nsg.Id)
assert.Equal(t, "Main", nsg.Name)
assert.True(t, nsg.Primary)
require.Len(t, nsg.Nameservers, 1)
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
}
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
c := newTestComponents()
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
c.PostureFailedPeers = map[string]map[string]struct{}{
"check-1": {
"peer-a": {},
"peer-b": {},
"peer-not-in-account": {},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.PostureFailedPeers, uint32(33))
idxs := full.PostureFailedPeers[33].PeerIndexes
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
}
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
c := newTestComponents()
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"peer-c": {
ID: "router-1", AccountSeqID: 200,
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
entries := full.RoutersMap[5].Entries
require.Len(t, entries, 1)
e := entries[0]
assert.EqualValues(t, 200, e.Id)
assert.True(t, e.PeerIndexSet)
require.Less(t, int(e.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
assert.True(t, e.Masquerade)
assert.EqualValues(t, 10, e.Metric)
assert.True(t, e.Enabled)
}
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
// peer_index reference must still resolve.
c := newTestComponents()
delete(c.Peers, "peer-c")
routerPeer := &nbpeer.Peer{
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
require.Len(t, full.RoutersMap[5].Entries, 1)
e := full.RoutersMap[5].Entries[0]
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
}
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
c := newTestComponents()
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.DnsSettings)
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
}
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
c := newTestComponents()
c.GroupIDToUserIDs = map[string][]string{
"group-src": {"user-1", "user-2"},
"group-no-seq": {"user-3"}, // group not persisted → drop
"group-missing": {"user-4"}, // group not in components → drop
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
require.Contains(t, full.GroupIdToUserIds, uint32(1))
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
}
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
}
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
nm := &types.NetworkMap{
Peers: []*nbpeer.Peer{{
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}},
FirewallRules: []*types.FirewallRule{{
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
}},
}
patch := toProxyPatch(nm, "netbird.cloud", false, false)
require.NotNil(t, patch)
assert.Len(t, patch.Peers, 1)
assert.Len(t, patch.FirewallRules, 1)
}
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
// pass-through in both encoder branches (normal path + nil-Components
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
// from one of the envelope struct literals.
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
patch := &proto.ProxyPatch{
ForwardingRules: []*proto.ForwardingRule{{
Protocol: proto.RuleProtocol_TCP,
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
}},
}
t.Run("normal_path", func(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
}
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
// nil Components → minimal envelope, no crash. Matches the legacy
// behaviour for missing/unvalidated peers.
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full)
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
assert.Equal(t, "netbird.cloud", full.DnsDomain)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
// AccountSettings deliberately nil
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
}

View File

@@ -0,0 +1,192 @@
package grpc
import (
"context"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ToComponentSyncResponse builds a SyncResponse carrying the compact
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
// field is intentionally left empty — capable peers ignore it and the
// envelope alone is the authoritative wire shape.
//
// PeerConfig is computed once server-side using the receiving peer's own
// account-level network metadata. EnableSSH inside PeerConfig is left at
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
// client even though the server-side PeerConfig reports false.
func ToComponentSyncResponse(
ctx context.Context,
config *nbconfig.Config,
httpConfig *nbconfig.HttpServerConfig,
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
peer *nbpeer.Peer,
turnCredentials *Token,
relayCredentials *Token,
components *types.NetworkMapComponents,
proxyPatch *types.NetworkMap,
dnsName string,
checks []*posture.Checks,
settings *types.Settings,
extraSettings *types.ExtraSettings,
peerGroups []string,
dnsFwdPort int64,
) *proto.SyncResponse {
network := networkOrZero(components)
enableSSH := computeSSHEnabledForPeer(components, peer)
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
useSourcePrefixes := peer.SupportsSourcePrefixes()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: components,
PeerConfig: peerConfig,
DNSDomain: dnsName,
DNSForwarderPort: dnsFwdPort,
UserIDClaim: userIDClaim,
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
})
resp := &proto.SyncResponse{
PeerConfig: peerConfig,
NetworkMapEnvelope: envelope,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
return resp
}
// networkOrZero returns components.Network or a zero Network — toPeerConfig
// dereferences network.Net which would panic on nil.
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
if c == nil || c.Network == nil {
return &types.Network{}
}
return c.Network
}
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
// patch the components envelope ships alongside. Returns nil when there are
// no fragments to merge — proto3 omits a nil message field, so the receiver
// sees no patch and skips the merge step entirely.
//
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
// delivers fragments pre-expanded — there's no raw component shape to
// derive them from. Components purity isn't violated: proxy data isn't
// policy-graph-derived, it's externally injected post-Calculate, so the
// client merges it on top of its locally-computed NetworkMap.
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
if nm == nil {
return nil
}
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
return nil
}
patch := &proto.ProxyPatch{
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
Routes: networkmap.ToProtocolRoutes(nm.Routes),
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
}
if len(nm.ForwardingRules) > 0 {
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
for _, r := range nm.ForwardingRules {
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
}
}
return patch
}
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
// without this helper the field would be incorrectly false for any peer
// that's the destination of an SSH-enabling policy without having
// peer.SSHEnabled set locally.
//
// Mirrors the two activation paths Calculate() uses:
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
// destinations.
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
// destinations, AND the peer has SSHEnabled set locally — this is the
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
//
// The full SSH AuthorizedUsers map is still produced by the client when it
// runs Calculate() over the envelope.
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
if c == nil || peer == nil {
return false
}
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
// exist in c.Peers, otherwise no rule applies to it.
if _, ok := c.Peers[peer.ID]; !ok {
return false
}
for _, policy := range c.Policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if ruleEnablesSSHForPeer(c, rule, peer) {
return true
}
}
}
return false
}
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
// peer itself has SSH enabled locally.
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
if rule == nil || !rule.Enabled {
return false
}
if !peerInDestinations(c, rule, peer.ID) {
return false
}
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
return true
}
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
}
// peerInDestinations reports whether peerID is in any of rule.Destinations'
// groups (or matches DestinationResource if it's a peer-typed resource —
// for non-peer types Calculate falls through to group lookup, so we mirror
// that exactly to avoid silent divergence).
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if c.IsPeerInGroup(peerID, groupID) {
return true
}
}
return false
}

View File

@@ -0,0 +1,184 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -6,24 +6,22 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/timestamppb"
goproto "google.golang.org/protobuf/proto"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil" "github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
) )
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -138,8 +136,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes), Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
}, },
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
@@ -152,19 +150,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6) remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0 response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6) response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes) firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
@@ -177,7 +175,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
} }
if networkMap.AuthorizedUsers != nil { if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers) hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" { if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim userIDClaim = httpConfig.AuthUserIDClaim
@@ -185,79 +183,36 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim} response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
} }
// settings == nil → field stays nil → "no info in this snapshot", client
// preserves the deadline it already had. settings non-nil → emit either a
// valid deadline or the explicit-zero "disabled" sentinel via
// encodeSessionExpiresAt.
if settings != nil {
response.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
}
return response return response
} }
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) { // encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
userIDToIndex := make(map[string]uint32) // representation used on LoginResponse, SyncResponse and
var hashedUsers [][]byte // ExtendAuthSessionResponse. See the proto comments on those messages.
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers)) //
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
for machineUser, users := range authorizedUsers { // "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
indexes := make([]uint32, 0, len(users)) // its anchor.
for userID := range users { // - deadline non-zero → returns timestamppb.New(deadline): the new absolute
idx, exists := userIDToIndex[userID] // UTC deadline.
if !exists { //
hash, err := sshauth.HashUserID(userID) // Returning nil ("no info, preserve client's anchor") is the caller's job —
if err != nil { // only meaningful on Sync builds where settings were not resolved.
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err) func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
continue if deadline.IsZero() {
} return &timestamppb.Timestamp{}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
} }
return timestamppb.New(deadline)
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
} }
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
@@ -277,204 +232,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
} }
} }
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
// alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
// when includeIPv6 is true.
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
// IPv4Unspecified/0 is always valid, error is impossible.
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
// IPv6Unspecified/0 is always valid, error is impossible.
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if shouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// buildJWTConfig constructs JWT configuration for SSH servers from management server config // buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig { func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
if config == nil || config.AuthAudience == "" { if config == nil || config.AuthAudience == "" {

View File

@@ -5,6 +5,7 @@ import (
"net/netip" "net/netip"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/shared/management/networkmap"
) )
func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -61,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
} }
// First run with config1 // First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2 // Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again // Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical // Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) { if !reflect.DeepEqual(result1, result3) {
@@ -99,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
} }
}) })
@@ -107,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{} cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
} }
}) })
} }
@@ -200,3 +202,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}) })
} }
} }
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
// "expiry disabled or peer is not SSO-tracked" sentinel.
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
//
// The third state (nil pointer = "no info in this snapshot") is the caller's
// responsibility on the Sync path when settings could not be resolved; the
// helper itself never returns nil.
func TestEncodeSessionExpiresAt(t *testing.T) {
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
got := encodeSessionExpiresAt(time.Time{})
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
assert.Equal(t, int64(0), got.GetSeconds())
assert.Equal(t, int32(0), got.GetNanos())
})
t.Run("non-zero deadline round-trips", func(t *testing.T) {
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
got := encodeSessionExpiresAt(deadline)
assert.NotNil(t, got)
assert.True(t, got.AsTime().Equal(deadline))
})
}

View File

@@ -437,7 +437,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
return nil return nil
} }
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) log.WithContext(ctx).Tracef("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) { if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period) // Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
@@ -492,7 +492,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message") return status.Errorf(codes.Internal, "failed sending update message")
} }
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) log.WithContext(ctx).Tracef("sent an update to peer %s", peerKey.String())
return nil return nil
} }
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}, nil }, nil
} }
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
// stays up; no network map sync is performed. The new deadline is returned
// in ExtendAuthSessionResponse.SessionExpiresAt.
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
extendReq := &proto.ExtendAuthSessionRequest{}
peerKey, err := s.parseRequest(ctx, req, extendReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
}
jwt := extendReq.GetJwtToken()
if jwt == "" {
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
}
var userID string
const attempts = 3
for i := 0; i < attempts; i++ {
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
if err == nil {
break
}
if i == attempts-1 {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
select {
case <-time.After(200 * time.Millisecond):
case <-ctx.Done():
return nil, ctx.Err()
}
}
if err != nil {
return nil, err
}
if userID == "" {
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
}
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
if err != nil {
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
// Success path normally returns a non-zero deadline. A defensive zero
// would still encode as the explicit "disabled" sentinel rather than nil,
// so the client clears any stale anchor instead of preserving it.
resp := &proto.ExtendAuthSessionResponse{
SessionExpiresAt: encodeSessionExpiresAt(deadline),
}
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed processing request")
}
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed encrypting response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encrypted,
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token var relayToken *Token
var err error var err error
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
Checks: toProtocolChecks(ctx, postureChecks), Checks: toProtocolChecks(ctx, postureChecks),
} }
// settings is always non-nil here, so we never emit nil — encoder returns
// either a valid deadline or the explicit-zero "disabled" sentinel.
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
return loginResp, nil return loginResp, nil
} }
@@ -932,7 +1012,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err) return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
} }
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) dnsName := s.networkMapController.GetDNSDomain(settings)
var plainResp *proto.SyncResponse
if s.networkMapController.PeerNeedsComponents(peer) {
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
// computed and recompute the raw components instead. This wastes one
// Calculate() call per initial-sync — the component-based wire
// format is what the peer actually consumes. The streaming path
// (network_map.Controller.UpdateAccountPeers) skips this duplication
// because it dispatches by capability before computing.
//
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
// interfaces to return PeerNetworkMapResult so the initial-sync path
// stops doing duplicate work. Deferred until the client-side
// decoder lands and there's a real deployment of capability=3 peers
// worth optimizing for.
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
}
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
key, err := s.secretsManager.GetWGKey() key, err := s.secretsManager.GetWGKey()
if err != nil { if err != nil {

View File

@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain || oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion || oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways { oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
// Session deadline is derived from LastLogin + PeerLoginExpiration
// on every Login/Sync response. Without a fan-out push, connected
// peers keep the deadline they received at login time and only see
// the new value after the next unrelated NetworkMap change. Add
// these two fields to the trigger list so admin-side expiry tweaks
// (e.g. shortening from 24h to 1h) reach every connected peer
// within seconds, which is what the proactive-warning feature
// relies on (see client/internal/auth/sessionwatch).
updateAccountPeers = true updateAccountPeers = true
} }
@@ -1621,6 +1631,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
for _, g := range newGroupsToCreate {
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
if err != nil {
return fmt.Errorf("error allocating group seq id: %w", err)
}
g.AccountSeqID = seq
}
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil { if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }

View File

@@ -109,6 +109,7 @@ type Manager interface {
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)

View File

@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
} }
// ExtendPeerSession mocks base method.
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
ret0, _ := ret[0].(time.Time)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
}
// MarkPeerConnected mocks base method. // MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
var newJWTGroup *types.Group
for _, g := range groups {
if g.Name == "group3" {
newJWTGroup = g
break
}
}
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
}) })
t.Run("remove all JWT groups when list is empty", func(t *testing.T) { t.Run("remove all JWT groups when list is empty", func(t *testing.T) {

View File

@@ -240,6 +240,10 @@ const (
AccountLocalMfaEnabled Activity = 123 AccountLocalMfaEnabled Activity = 123
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users // AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
AccountLocalMfaDisabled Activity = 124 AccountLocalMfaDisabled Activity = 124
// UserExtendedPeerSession indicates that a user refreshed their peer's
// SSO session deadline via ExtendAuthSession without re-establishing the
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
UserExtendedPeerSession Activity = 125
AccountDeleted Activity = 99999 AccountDeleted Activity = 99999
) )
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"}, AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"}, AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
DomainAdded: {"Domain added", "domain.add"}, DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"}, DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"}, DomainValidated: {"Domain validated", "domain.validate"},

View File

@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
}
newGroup.AccountSeqID = seq
if err := transaction.CreateGroup(ctx, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err) return status.Errorf(status.Internal, "failed to create group: %v", err)
} }
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err = transaction.UpdateGroup(ctx, newGroup); err != nil { if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err return err
} }
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
newGroup.AccountID = accountID newGroup.AccountID = accountID
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return err
}
newGroup.AccountSeqID = seq
if err = transaction.CreateGroup(ctx, newGroup); err != nil { if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err return err
} }
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
newGroup.AccountID = accountID newGroup.AccountID = accountID
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return err
}
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err := transaction.UpdateGroup(ctx, newGroup); err != nil { if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
return err return err
} }

View File

@@ -0,0 +1,156 @@
package migration
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/types"
)
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
// with the next free id per account. Idempotent: safe to re-run; both steps
// no-op once everything is consistent.
//
// Implemented as two table-wide SQL statements with window functions, one
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
// well under a second instead of the per-account-loop ~2 minutes.
//
// orderColumn is the column to use when assigning the deterministic ordering
// (typically the primary-key string id).
func BackfillAccountSeqIDs[T any](
ctx context.Context,
db *gorm.DB,
entity types.AccountSeqEntity,
orderColumn string,
) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
return nil
}
stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("parse model: %w", err)
}
table := quoteIdent(db, stmt.Schema.Table)
orderCol := quoteIdent(db, orderColumn)
return db.Transaction(func(tx *gorm.DB) error {
var pending int64
if err := tx.Raw(
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
).Scan(&pending).Error; err != nil {
return fmt.Errorf("count pending on %s: %w", table, err)
}
if pending > 0 {
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
if err := backfillRankSQL(tx, table, orderCol); err != nil {
return fmt.Errorf("rank %s: %w", table, err)
}
}
if err := seedCountersSQL(tx, table, entity); err != nil {
return fmt.Errorf("seed counters for %s: %w", entity, err)
}
return nil
})
}
func quoteIdent(db *gorm.DB, name string) string {
switch db.Dialector.Name() {
case "mysql":
return "`" + name + "`"
case "postgres":
return `"` + name + `"`
default:
return name
}
}
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres", "sqlite":
sql = fmt.Sprintf(`
WITH max_seq AS (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
),
ranked AS (
SELECT p.id,
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
FROM %s p
JOIN max_seq m ON p.account_id = m.account_id
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
)
UPDATE %s SET account_seq_id = ranked.new_seq
FROM ranked
WHERE %s.id = ranked.id
`, table, orderCol, table, table, table)
case "mysql":
sql = fmt.Sprintf(`
UPDATE %s p
JOIN (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
) m ON p.account_id = m.account_id
JOIN (
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
FROM %s
WHERE account_seq_id IS NULL OR account_seq_id = 0
) r ON p.id = r.id
SET p.account_seq_id = m.max_seq + r.rn
`, table, table, orderCol, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql).Error
}
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
`, table)
case "sqlite":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
`, table)
case "mysql":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
`, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql, string(entity)).Error
}

View File

@@ -98,6 +98,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
@@ -860,6 +861,14 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
} }
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
func (am *MockAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
if am.ExtendPeerSessionFunc != nil {
return am.ExtendPeerSessionFunc(ctx, peerPubKey, userID)
}
return time.Time{}, status.Errorf(codes.Unimplemented, "method ExtendPeerSession is not implemented")
}
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {

View File

@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
if err != nil {
return err
}
newNSGroup.AccountSeqID = seq
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err return err
} }
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err return err
} }

View File

@@ -71,9 +71,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
network.ID = xid.New().String() network.ID = xid.New().String()
err = m.store.SaveNetwork(ctx, network) err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
if err != nil {
return fmt.Errorf("failed to allocate network seq id: %w", err)
}
network.AccountSeqID = seq
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err) return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
@@ -102,14 +113,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
network.AccountSeqID = existing.AccountSeqID
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, network) return network, nil
} }
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {

View File

@@ -255,3 +255,73 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Nil(t, updatedNetwork) require.Nil(t, updatedNetwork)
} }
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
// non-zero AccountSeqID on the persisted network (allocated through the
// account_seq_counters table).
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
ctx := context.Background()
const accountID = "testAccountId"
const userID = "testAdminId"
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
AccountID: accountID,
Name: "seq-allocation-test",
})
require.NoError(t, err)
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
}
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
// AccountSeqID even when the caller passes a zero value (the shape REST
// handlers produce because the field is `json:"-"`).
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
ctx := context.Background()
const accountID = "testAccountId"
const userID = "testAdminId"
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
AccountID: accountID,
Name: "seq-preserve-original",
})
require.NoError(t, err)
originalSeq := created.AccountSeqID
require.NotZero(t, originalSeq)
update := &types.Network{
AccountID: accountID,
ID: created.ID,
Name: "seq-preserve-renamed",
}
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
_, err = manager.UpdateNetwork(ctx, userID, update)
require.NoError(t, err)
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
require.Equal(t, "seq-preserve-renamed", got.Name)
}

View File

@@ -125,6 +125,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
if err != nil {
return fmt.Errorf("failed to allocate network resource seq id: %w", err)
}
resource.AccountSeqID = seq
err = transaction.SaveNetworkResource(ctx, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {
return fmt.Errorf("failed to save network resource: %w", err) return fmt.Errorf("failed to save network resource: %w", err)
@@ -231,6 +237,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
resource.AccountSeqID = oldResource.AccountSeqID
err = transaction.SaveNetworkResource(ctx, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {

View File

@@ -32,6 +32,9 @@ type NetworkResource struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"` NetworkID string `gorm:"index"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
Name string Name string
Description string Description string
Type NetworkResourceType Type NetworkResourceType
@@ -93,17 +96,18 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
func (n *NetworkResource) Copy() *NetworkResource { func (n *NetworkResource) Copy() *NetworkResource {
return &NetworkResource{ return &NetworkResource{
ID: n.ID, ID: n.ID,
AccountID: n.AccountID, AccountID: n.AccountID,
NetworkID: n.NetworkID, NetworkID: n.NetworkID,
Name: n.Name, AccountSeqID: n.AccountSeqID,
Description: n.Description, Name: n.Name,
Type: n.Type, Description: n.Description,
Address: n.Address, Type: n.Type,
Domain: n.Domain, Address: n.Address,
Prefix: n.Prefix, Domain: n.Domain,
GroupIDs: n.GroupIDs, Prefix: n.Prefix,
Enabled: n.Enabled, GroupIDs: n.GroupIDs,
Enabled: n.Enabled,
} }
} }

View File

@@ -102,6 +102,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
router.ID = xid.New().String() router.ID = xid.New().String()
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
if err != nil {
return fmt.Errorf("failed to allocate network router seq id: %w", err)
}
router.AccountSeqID = seq
err = transaction.CreateNetworkRouter(ctx, router) err = transaction.CreateNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to create network router: %w", err) return fmt.Errorf("failed to create network router: %w", err)
@@ -175,6 +181,14 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
} }
// Preserve AccountSeqID from the existing router so the upstream
// UpdateNetworkRouter (which does Updates(router) with Select("*"))
// doesn't clobber it with the request's zero value. Main's split of
// Save → Create/Update removed the PUT-as-upsert path that the
// earlier branch carried, so callers must always pass a real router
// id — UpdateNetworkRouter returns NotFound otherwise.
router.AccountSeqID = existing.AccountSeqID
err = transaction.UpdateNetworkRouter(ctx, router) err = transaction.UpdateNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to update network router: %w", err) return fmt.Errorf("failed to update network router: %w", err)

View File

@@ -13,6 +13,9 @@ type NetworkRouter struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"` NetworkID string `gorm:"index"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
Peer string Peer string
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
Masquerade bool Masquerade bool
@@ -78,14 +81,15 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
func (n *NetworkRouter) Copy() *NetworkRouter { func (n *NetworkRouter) Copy() *NetworkRouter {
return &NetworkRouter{ return &NetworkRouter{
ID: n.ID, ID: n.ID,
NetworkID: n.NetworkID, NetworkID: n.NetworkID,
AccountID: n.AccountID, AccountID: n.AccountID,
Peer: n.Peer, AccountSeqID: n.AccountSeqID,
PeerGroups: n.PeerGroups, Peer: n.Peer,
Masquerade: n.Masquerade, PeerGroups: n.PeerGroups,
Metric: n.Metric, Masquerade: n.Masquerade,
Enabled: n.Enabled, Metric: n.Metric,
Enabled: n.Enabled,
} }
} }

View File

@@ -7,12 +7,24 @@ import (
) )
type Network struct { type Network struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
Name string Name string
Description string Description string
} }
// HasSeqID reports whether the network has been persisted long enough to have
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
// must skip networks that return false here.
func (n *Network) HasSeqID() bool {
return n != nil && n.AccountSeqID != 0
}
func NewNetwork(accountId, name, description string) *Network { func NewNetwork(accountId, name, description string) *Network {
return &Network{ return &Network{
ID: xid.New().String(), ID: xid.New().String(),
@@ -41,13 +53,14 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
} }
} }
// Copy returns a copy of a posture checks. // Copy returns a copy of a network.
func (n *Network) Copy() *Network { func (n *Network) Copy() *Network {
return &Network{ return &Network{
ID: n.ID, ID: n.ID,
AccountID: n.AccountID, AccountID: n.AccountID,
Name: n.Name, AccountSeqID: n.AccountSeqID,
Description: n.Description, Name: n.Name,
Description: n.Description,
} }
} }

View File

@@ -1151,6 +1151,79 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return p, nmap, pc, err return p, nmap, pc, err
} }
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
// LastLogin after a successful JWT validation. The tunnel is untouched: no
// network map sync, no peer reconnect.
//
// Preconditions enforced here:
// - userID must be present (caller validated the JWT and extracted the user ID).
// - The peer must exist and be SSO-registered (AddedWithSSOLogin) with
// LoginExpirationEnabled.
// - Account-level PeerLoginExpirationEnabled must be true.
// - The JWT user must match peer.UserID (mirrors LoginPeer at peer.go ~1028).
//
// Returns the new absolute UTC deadline.
func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
if userID == "" {
return time.Time{}, status.Errorf(status.PermissionDenied, "session extend requires a JWT")
}
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
return time.Time{}, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return time.Time{}, err
}
if !settings.PeerLoginExpirationEnabled {
return time.Time{}, status.Errorf(status.PreconditionFailed, "peer login expiration is disabled for the account")
}
var refreshed *nbpeer.Peer
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
if err != nil {
return err
}
if !peer.AddedWithSSOLogin() || !peer.LoginExpirationEnabled {
return status.Errorf(status.PreconditionFailed, "peer is not eligible for session extension")
}
if peer.UserID != userID {
log.WithContext(ctx).Warnf("user mismatch when extending session for peer %s: peer user %s, jwt user %s", peer.ID, peer.UserID, userID)
return status.NewPeerLoginMismatchError()
}
peer = peer.UpdateLastLogin()
if err := transaction.SavePeer(ctx, accountID, peer); err != nil {
return err
}
if err := transaction.SaveUserLastLogin(ctx, accountID, userID, peer.GetLastLogin()); err != nil {
log.WithContext(ctx).Debugf("failed to update user last login during session extend: %v", err)
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.UserExtendedPeerSession, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
refreshed = peer
return nil
})
if err != nil {
return time.Time{}, err
}
// Reschedule the per-account expiration job. schedulePeerLoginExpiration
// is a no-op when a job is already running, but the running job will pick
// up the new LastLogin on its next tick. Calling it here is harmless and
// guarantees a job is in flight even if a prior one ended right before
// the extend.
am.schedulePeerLoginExpiration(ctx, accountID)
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
}
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)

View File

@@ -13,8 +13,9 @@ import (
// Peer capability constants mirror the proto enum values. // Peer capability constants mirror the proto enum values.
const ( const (
PeerCapabilitySourcePrefixes int32 = 1 PeerCapabilitySourcePrefixes int32 = 1
PeerCapabilityIPv6Overlay int32 = 2 PeerCapabilityIPv6Overlay int32 = 2
PeerCapabilityComponentNetworkMap int32 = 3
) )
// Peer represents a machine connected to the network. // Peer represents a machine connected to the network.
@@ -247,6 +248,14 @@ func (p *Peer) SupportsSourcePrefixes() bool {
return p.HasCapability(PeerCapabilitySourcePrefixes) return p.HasCapability(PeerCapabilitySourcePrefixes)
} }
// SupportsComponentNetworkMap reports whether the peer assembles its
// NetworkMap from server-shipped components instead of consuming a fully
// expanded NetworkMap. Determines whether the network_map controller skips
// Calculate() server-side and emits the components envelope.
func (p *Peer) SupportsComponentNetworkMap() bool {
return p.HasCapability(PeerCapabilityComponentNetworkMap)
}
func capabilitiesEqual(a, b []int32) bool { func capabilitiesEqual(a, b []int32) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false
@@ -367,6 +376,22 @@ func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
return timeLeft <= 0, timeLeft return timeLeft <= 0, timeLeft
} }
// SessionExpiresAt returns the absolute UTC instant at which the peer's SSO
// session expires, derived from LastLogin and the account-level
// PeerLoginExpiration setting. Returns the zero value when login expiration
// does not apply (peer not SSO-registered, peer-level toggle off, or account
// expiry disabled). Callers should treat the zero value as "no deadline".
func (p *Peer) SessionExpiresAt(accountExpirationEnabled bool, expiresIn time.Duration) time.Time {
if !accountExpirationEnabled || !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
return time.Time{}
}
last := p.GetLastLogin()
if last.IsZero() {
return time.Time{}
}
return last.Add(expiresIn).UTC()
}
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain // FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
func (p *Peer) FQDN(dnsDomain string) string { func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" { if dnsDomain == "" {

View File

@@ -69,6 +69,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
policy.AccountSeqID = existingPolicy.AccountSeqID
if err = transaction.SavePolicy(ctx, policy); err != nil { if err = transaction.SavePolicy(ctx, policy); err != nil {
return err return err
} }
@@ -78,6 +80,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
policy.AccountSeqID = seq
if err = transaction.CreatePolicy(ctx, policy); err != nil { if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err return err
} }

View File

@@ -47,10 +47,21 @@ type Checks struct {
// AccountID is a reference to the Account that this object belongs // AccountID is a reference to the Account that this object belongs
AccountID string `json:"-" gorm:"index"` AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
// Checks is a set of objects that perform the actual checks // Checks is a set of objects that perform the actual checks
Checks ChecksDefinition `gorm:"serializer:json"` Checks ChecksDefinition `gorm:"serializer:json"`
} }
// HasSeqID reports whether the posture check has been persisted long enough
// to have a per-account sequence id allocated. Wire encoders that key off
// AccountSeqID must skip checks that return false here.
func (pc *Checks) HasSeqID() bool {
return pc != nil && pc.AccountSeqID != 0
}
// ChecksDefinition contains definition of actual check // ChecksDefinition contains definition of actual check
type ChecksDefinition struct { type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"` NBVersionCheck *NBVersionCheck `json:",omitempty"`
@@ -121,11 +132,12 @@ func (*Checks) TableName() string {
// Copy returns a copy of a posture checks. // Copy returns a copy of a posture checks.
func (pc *Checks) Copy() *Checks { func (pc *Checks) Copy() *Checks {
checks := &Checks{ checks := &Checks{
ID: pc.ID, ID: pc.ID,
Name: pc.Name, Name: pc.Name,
Description: pc.Description, Description: pc.Description,
AccountID: pc.AccountID, AccountID: pc.AccountID,
Checks: pc.Checks.Copy(), AccountSeqID: pc.AccountSeqID,
Checks: pc.Checks.Copy(),
} }
return checks return checks
} }

View File

@@ -6,7 +6,6 @@ import (
"strings" "strings"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
@@ -33,9 +32,6 @@ func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, err
return true, nil return true, nil
} }
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
peer.ID, peer.Meta.WtVersion, n.MinVersion)
return false, nil return false, nil
} }

View File

@@ -100,8 +100,6 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
return true, nil return true, nil
} }
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
return false, nil return false, nil
} }
@@ -125,7 +123,5 @@ func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, ch
return true, nil return true, nil
} }
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
return false, nil return false, nil
} }

View File

@@ -51,12 +51,24 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
if isUpdate { if isUpdate {
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
if err != nil {
return err
}
postureChecks.AccountSeqID = existing.AccountSeqID
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil { if err != nil {
return err return err
} }
action = activity.PostureCheckUpdated action = activity.PostureCheckUpdated
} else {
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
if err != nil {
return err
}
postureChecks.AccountSeqID = seq
} }
postureChecks.AccountID = accountID postureChecks.AccountID = accountID

View File

@@ -563,3 +563,61 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
assert.False(t, result) assert.False(t, result)
}) })
} }
// TestSavePostureChecks_AllocatesSeqIDOnCreate verifies that the create path
// (no incoming ID) allocates a non-zero AccountSeqID via the
// account_seq_counters table.
func TestSavePostureChecks_AllocatesSeqIDOnCreate(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
account, err := initTestPostureChecksAccount(am)
require.NoError(t, err)
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: "seq-allocation-test",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
},
}, true)
require.NoError(t, err)
require.NotZero(t, created.AccountSeqID, "SavePostureChecks on create must allocate a non-zero AccountSeqID")
}
// TestSavePostureChecks_PreservesSeqIDOnUpdate verifies the update path does
// not reset AccountSeqID even when the caller passes a zero value (REST
// handler shape, because the field is `json:"-"`).
func TestSavePostureChecks_PreservesSeqIDOnUpdate(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
account, err := initTestPostureChecksAccount(am)
require.NoError(t, err)
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: "seq-preserve-original",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
},
}, true)
require.NoError(t, err)
originalSeq := created.AccountSeqID
require.NotZero(t, originalSeq)
update := &posture.Checks{
ID: created.ID,
Name: "seq-preserve-renamed",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.27.0"},
},
}
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, update, false)
require.NoError(t, err)
got, err := am.GetPostureChecks(context.Background(), account.Id, created.ID, adminUserID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive SavePostureChecks update")
require.Equal(t, "seq-preserve-renamed", got.Name)
}

View File

@@ -178,6 +178,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err return err
} }
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
if err != nil {
return err
}
newRoute.AccountSeqID = seq
if err = transaction.SaveRoute(ctx, newRoute); err != nil { if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err return err
} }
@@ -231,6 +237,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
routeToSave.AccountID = accountID routeToSave.AccountID = accountID
routeToSave.AccountSeqID = oldRoute.AccountSeqID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil { if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err return err

View File

@@ -0,0 +1,506 @@
package store
import (
"context"
"errors"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
var errRollback = errors.New("intentional rollback")
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accA = "acc-a"
const accB = "acc-b"
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(1), got)
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(2), got)
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(1), got, "different account starts from 1")
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
require.NoError(t, err)
require.Equal(t, uint32(1), got, "different entity starts from 1")
return nil
}))
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(3), got, "counter persists across transactions")
return nil
}))
}
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotEmpty(t, policies, "test fixture must have policies")
seen := make(map[uint32]bool)
for _, p := range policies {
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
seen[p.AccountSeqID] = true
}
}
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
const policyID = "cs1tnh0hhcjnqoiuebf0"
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
require.NoError(t, err)
originalSeq := original.AccountSeqID
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
updated := &types.Policy{
ID: policyID,
AccountID: accountID,
Name: "renamed",
Enabled: false,
Rules: original.Rules,
}
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
require.NoError(t, store.SavePolicy(ctx, updated))
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
require.Equal(t, "renamed", got.Name)
}
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotEmpty(t, groups)
original := groups[0]
originalSeq := original.AccountSeqID
require.NotZero(t, originalSeq)
updated := &types.Group{
ID: original.ID,
AccountID: accountID,
Name: "renamed",
Issued: original.Issued,
}
require.Zero(t, updated.AccountSeqID)
require.NoError(t, store.UpdateGroup(ctx, updated))
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
require.Equal(t, "renamed", got.Name)
}
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "save-account-seqid-test"
account := &types.Account{
Id: accountID,
CreatedBy: "user1",
Domain: "example.test",
DNSSettings: types.DNSSettings{},
Settings: &types.Settings{},
Network: &types.Network{
Identifier: "net-test",
},
Users: map[string]*types.User{
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
},
}
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
require.Len(t, account.Groups, 1, "default 'All' group must be present")
require.Len(t, account.Policies, 1, "default policy must be present")
for _, g := range account.Groups {
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
}
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
require.NoError(t, store.SaveAccount(ctx, account))
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.Len(t, groups, 1)
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.Len(t, policies, 1)
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
require.NoError(t, err)
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
return errRollback
}), errRollback)
}
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
groupSeqs := make(map[string]uint32)
policySeqs := make(map[string]uint32)
routeSeqs := make(map[route.ID]uint32)
nsgSeqs := make(map[string]uint32)
resourceSeqs := make(map[string]uint32)
routerSeqs := make(map[string]uint32)
networkSeqs := make(map[string]uint32)
for _, g := range account.Groups {
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
groupSeqs[g.ID] = g.AccountSeqID
}
for _, p := range account.Policies {
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
policySeqs[p.ID] = p.AccountSeqID
}
for _, r := range account.Routes {
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
routeSeqs[r.ID] = r.AccountSeqID
}
for _, n := range account.NameServerGroups {
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
nsgSeqs[n.ID] = n.AccountSeqID
}
for _, nr := range account.NetworkResources {
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
resourceSeqs[nr.ID] = nr.AccountSeqID
}
for _, nr := range account.NetworkRouters {
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
routerSeqs[nr.ID] = nr.AccountSeqID
}
for _, n := range account.Networks {
require.NotZero(t, n.AccountSeqID, "fixture network must have seq>0 after backfill")
networkSeqs[n.ID] = n.AccountSeqID
}
require.NoError(t, store.SaveAccount(ctx, account))
after, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
for _, g := range after.Groups {
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
}
for _, p := range after.Policies {
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
}
for _, r := range after.Routes {
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
}
for _, n := range after.NameServerGroups {
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
}
for _, nr := range after.NetworkResources {
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
}
for _, nr := range after.NetworkRouters {
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
}
for _, n := range after.Networks {
require.Equal(t, networkSeqs[n.ID], n.AccountSeqID, "network %s seq must be preserved", n.ID)
}
}
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "save-account-all-entities"
addr, err := netip.ParseAddr("8.8.8.8")
require.NoError(t, err)
account := &types.Account{
Id: accountID,
CreatedBy: "user1",
Domain: "example.test",
Settings: &types.Settings{},
Network: &types.Network{Identifier: "net-test"},
Users: map[string]*types.User{
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
},
Groups: map[string]*types.Group{
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
},
Policies: []*types.Policy{
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
},
Routes: map[route.ID]*route.Route{
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
},
Networks: []*networkTypes.Network{
{ID: "n1", AccountID: accountID, Name: "n1"},
},
PostureChecks: []*posture.Checks{
{ID: "pc1", AccountID: accountID, Name: "pc1",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
}
require.NoError(t, store.SaveAccount(ctx, account))
after, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, after.Groups, 1)
require.Len(t, after.Policies, 1)
require.Len(t, after.Routes, 1)
require.Len(t, after.NameServerGroups, 1)
require.Len(t, after.NetworkResources, 1)
require.Len(t, after.NetworkRouters, 1)
require.Len(t, after.Networks, 1)
require.Len(t, after.PostureChecks, 1)
for _, g := range after.Groups {
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
}
for _, p := range after.Policies {
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
}
for _, r := range after.Routes {
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
}
for _, n := range after.NameServerGroups {
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
}
for _, nr := range after.NetworkResources {
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
}
for _, nr := range after.NetworkRouters {
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
}
for _, n := range after.Networks {
require.NotZero(t, n.AccountSeqID, "network seq must be allocated")
}
for _, pc := range after.PostureChecks {
require.NotZero(t, pc.AccountSeqID, "posture_check seq must be allocated")
}
require.NoError(t, store.SaveAccount(ctx, after))
final, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
for _, r := range final.Routes {
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
}
for _, n := range final.NameServerGroups {
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
}
afterByID := map[string]uint32{}
for _, n := range after.Networks {
afterByID[n.ID] = n.AccountSeqID
}
for _, n := range final.Networks {
require.Equal(t, afterByID[n.ID], n.AccountSeqID, "network seq preserved on re-save")
}
afterPCByID := map[string]uint32{}
for _, pc := range after.PostureChecks {
afterPCByID[pc.ID] = pc.AccountSeqID
}
for _, pc := range final.PostureChecks {
require.Equal(t, afterPCByID[pc.ID], pc.AccountSeqID, "posture_check seq preserved on re-save")
}
}
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "concurrent-test"
const entity = types.AccountSeqEntityPolicy
const goroutines = 32
type result struct {
seq uint32
err error
}
results := make(chan result, goroutines)
start := make(chan struct{})
for i := 0; i < goroutines; i++ {
go func() {
<-start
var allocated uint32
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
allocated = seq
return err
})
results <- result{seq: allocated, err: err}
}()
}
close(start)
seen := make(map[uint32]int, goroutines)
for i := 0; i < goroutines; i++ {
r := <-results
require.NoError(t, r.err, "concurrent allocate must not fail")
require.NotZero(t, r.seq, "allocated seq must be non-zero")
seen[r.seq]++
}
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
for i := uint32(1); i <= goroutines; i++ {
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
}
}
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*types.Group{
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
}
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
for _, want := range groups {
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
require.NoError(t, err)
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
}
groups[0].Name = "g1-renamed"
groups[0].AccountSeqID = 0
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
require.NoError(t, err)
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
}
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
maxSeq := uint32(0)
for _, p := range existing {
if p.AccountSeqID > maxSeq {
maxSeq = p.AccountSeqID
}
}
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
newPolicy := &types.Policy{
ID: "bench-new-policy",
AccountID: accountID,
AccountSeqID: seq,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "bench-new-policy-rule",
PolicyID: "bench-new-policy",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Sources: []string{"groupA"},
Destinations: []string{"groupC"},
Bidirectional: true,
}},
}
return tx.CreatePolicy(ctx, newPolicy)
}))
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
require.NoError(t, err)
require.Equal(t, maxSeq+1, created.AccountSeqID)
}

View File

@@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{}, &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{}, &accesslogs.AccessLogEntry{}, &proxy.Proxy{},
&types.AccountSeqCounter{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err) return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -307,6 +308,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
return result.Error return result.Error
} }
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
return fmt.Errorf("assign seq ids: %w", err)
}
result = tx. result = tx.
Session(&gorm.Session{FullSaveAssociations: true}). Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}). Clauses(clause.OnConflict{UpdateAll: true}).
@@ -658,6 +663,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
} }
// CreateGroups creates the given list of groups to the database. // CreateGroups creates the given list of groups to the database.
// groupUpsertColumns is the explicit allowlist of columns that get updated when
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
// handler-built struct) cannot reset the persisted seq id during an upsert.
// Keep this in sync with the Group schema in management/server/types/group.go.
func groupUpsertColumns() clause.Set {
return clause.AssignmentColumns([]string{
"account_id",
"name",
"issued",
"integration_ref_id",
"integration_ref_integration_type",
"resources",
})
}
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error { func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
if len(groups) == 0 { if len(groups) == 0 {
return nil return nil
@@ -667,8 +688,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
result := tx. result := tx.
Clauses( Clauses(
clause.OnConflict{ clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true, DoUpdates: groupUpsertColumns(),
}, },
). ).
Omit(clause.Associations). Omit(clause.Associations).
@@ -692,8 +714,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
result := tx. result := tx.
Clauses( Clauses(
clause.OnConflict{ clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true, DoUpdates: groupUpsertColumns(),
}, },
). ).
Omit(clause.Associations). Omit(clause.Associations).
@@ -2027,7 +2050,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
} }
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2037,7 +2060,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
var resources []byte var resources []byte
var refID sql.NullInt64 var refID sql.NullInt64
var refType sql.NullString var refType sql.NullString
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
if err == nil { if err == nil {
if refID.Valid { if refID.Valid {
g.IntegrationReference.ID = int(refID.Int64) g.IntegrationReference.ID = int(refID.Int64)
@@ -2062,7 +2085,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
} }
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2071,7 +2094,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
var p types.Policy var p types.Policy
var checks []byte var checks []byte
var enabled sql.NullBool var enabled sql.NullBool
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
if err == nil { if err == nil {
if enabled.Valid { if enabled.Valid {
p.Enabled = enabled.Bool p.Enabled = enabled.Bool
@@ -2089,7 +2112,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
} }
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2099,7 +2122,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
var network, domains, peerGroups, groups, accessGroups []byte var network, domains, peerGroups, groups, accessGroups []byte
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
var metric sql.NullInt64 var metric sql.NullInt64
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
if err == nil { if err == nil {
if keepRoute.Valid { if keepRoute.Valid {
r.KeepRoute = keepRoute.Bool r.KeepRoute = keepRoute.Bool
@@ -2141,7 +2164,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
} }
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2150,7 +2173,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
var n nbdns.NameServerGroup var n nbdns.NameServerGroup
var ns, groups, domains []byte var ns, groups, domains []byte
var primary, enabled, searchDomainsEnabled sql.NullBool var primary, enabled, searchDomainsEnabled sql.NullBool
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
if err == nil { if err == nil {
if primary.Valid { if primary.Valid {
n.Primary = primary.Bool n.Primary = primary.Bool
@@ -2186,7 +2209,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
} }
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, name, description, checks FROM posture_checks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2194,7 +2217,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
var c posture.Checks var c posture.Checks
var checksDef []byte var checksDef []byte
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) err := row.Scan(&c.ID, &c.AccountID, &c.AccountSeqID, &c.Name, &c.Description, &checksDef)
if err == nil && checksDef != nil { if err == nil && checksDef != nil {
_ = json.Unmarshal(checksDef, &c.Checks) _ = json.Unmarshal(checksDef, &c.Checks)
} }
@@ -2374,7 +2397,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
} }
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` const query = `SELECT id, account_id, account_seq_id, name, description FROM networks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2391,7 +2414,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
} }
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2401,7 +2424,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
var peerGroups []byte var peerGroups []byte
var masquerade, enabled sql.NullBool var masquerade, enabled sql.NullBool
var metric sql.NullInt64 var metric sql.NullInt64
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
if err == nil { if err == nil {
if masquerade.Valid { if masquerade.Valid {
r.Masquerade = masquerade.Bool r.Masquerade = masquerade.Bool
@@ -2429,7 +2452,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
} }
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID) rows, err := s.pool.Query(ctx, query, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2438,7 +2461,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
var r resourceTypes.NetworkResource var r resourceTypes.NetworkResource
var prefix []byte var prefix []byte
var enabled sql.NullBool var enabled sql.NullBool
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
if err == nil { if err == nil {
if enabled.Valid { if enabled.Valid {
r.Enabled = enabled.Bool r.Enabled = enabled.Bool
@@ -3611,6 +3634,262 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
} }
} }
// AllocateAccountSeqID returns the next per-account integer id for the given
// component kind. Must be called inside ExecuteInTransaction so the increment
// is serialized with the component insert.
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
}
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
switch engine {
case types.PostgresStoreEngine, types.SqliteStoreEngine:
return allocateAccountSeqIDReturning(db, accountID, entity)
case types.MysqlStoreEngine:
return allocateAccountSeqIDMysql(db, accountID, entity)
default:
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
}
}
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
// branch and returns next_id+1.
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, 2)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = account_seq_counters.next_id + 1
RETURNING (next_id - 1)
`
var allocated uint32
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
return 0, fmt.Errorf("upsert account seq counter: %w", err)
}
if allocated == 0 {
return 0, fmt.Errorf("upsert account seq counter returned 0")
}
return allocated, nil
}
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
// so the follow-up SELECT sees the same value.
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
const upsertSQL = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, LAST_INSERT_ID(2))
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
`
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
return 0, fmt.Errorf("upsert account seq counter: %w", err)
}
var newNext uint64
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
return 0, fmt.Errorf("get last insert id: %w", err)
}
if newNext == 0 {
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
}
return uint32(newNext - 1), nil
}
// assignAccountSeqIDs allocates a per-account integer id for any component on
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
// the canonical "save the whole account" path produces the same persisted seq
// ids that the manager-level Create paths produce. Update flows that go
// through SaveAccount preserve existing non-zero values; for those, the
// per-entity counter is bumped so subsequent AllocateAccountSeqID calls don't
// hand out a colliding id.
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
maxByEntity := make(map[types.AccountSeqEntity]uint32, 8)
bump := func(entity types.AccountSeqEntity, seq uint32) {
if seq > maxByEntity[entity] {
maxByEntity[entity] = seq
}
}
for i := range account.GroupsG {
g := account.GroupsG[i]
if g == nil {
continue
}
if g.AccountSeqID != 0 {
bump(types.AccountSeqEntityGroup, g.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
if err != nil {
return err
}
g.AccountSeqID = seq
// Defensive: generateAccountSQLTypes currently aliases the same
// *Group pointer into GroupsG and Groups[id] (so this is a no-op
// today), but mirror the seq anyway so any future divergence in
// how the two collections are populated doesn't silently leave
// the canonical map view stale.
if original, ok := account.Groups[g.ID]; ok && original != nil && original != g {
original.AccountSeqID = seq
}
}
for _, p := range account.Policies {
if p == nil {
continue
}
if p.AccountSeqID != 0 {
bump(types.AccountSeqEntityPolicy, p.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
p.AccountSeqID = seq
}
for i := range account.RoutesG {
r := &account.RoutesG[i]
if r.AccountSeqID != 0 {
bump(types.AccountSeqEntityRoute, r.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
if err != nil {
return err
}
r.AccountSeqID = seq
// Mirror the new seq onto the canonical map view so callers that
// hold the same in-memory account post-Save read a consistent
// AccountSeqID — without this, components/encoder code would see
// 0 for routes saved this transaction until the account is reloaded.
if original, ok := account.Routes[r.ID]; ok && original != nil {
original.AccountSeqID = seq
}
}
for i := range account.NameServerGroupsG {
ng := &account.NameServerGroupsG[i]
if ng.AccountSeqID != 0 {
bump(types.AccountSeqEntityNameserverGroup, ng.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
if err != nil {
return err
}
ng.AccountSeqID = seq
if original, ok := account.NameServerGroups[ng.ID]; ok && original != nil {
original.AccountSeqID = seq
}
}
for _, nr := range account.NetworkResources {
if nr == nil {
continue
}
if nr.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetworkResource, nr.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
if err != nil {
return err
}
nr.AccountSeqID = seq
}
for _, nr := range account.NetworkRouters {
if nr == nil {
continue
}
if nr.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetworkRouter, nr.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
if err != nil {
return err
}
nr.AccountSeqID = seq
}
for _, n := range account.Networks {
if n == nil {
continue
}
if n.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetwork, n.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetwork)
if err != nil {
return err
}
n.AccountSeqID = seq
}
for _, pc := range account.PostureChecks {
if pc == nil {
continue
}
if pc.AccountSeqID != 0 {
bump(types.AccountSeqEntityPostureCheck, pc.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPostureCheck)
if err != nil {
return err
}
pc.AccountSeqID = seq
}
for entity, maxSeq := range maxByEntity {
if err := ensureAccountSeqCounter(tx, s.storeEngine, account.Id, entity, maxSeq+1); err != nil {
return fmt.Errorf("seed counter for %s: %w", entity, err)
}
}
return nil
}
// ensureAccountSeqCounter raises the per-account counter for entity to at
// least target. Used when SaveAccount persists components that already carry
// AccountSeqIDs (e.g. test bulk-load from sqlite to postgres, or migrations
// running before component data lands) so that the next AllocateAccountSeqID
// call returns a fresh id beyond what was just written.
func ensureAccountSeqCounter(db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity, target uint32) error {
switch engine {
case types.PostgresStoreEngine, types.SqliteStoreEngine:
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
`
// sqlite's UPSERT understands max() but the migration uses GREATEST
// for postgres and max() for sqlite. We collapse to dialect-specific
// statements only when needed.
if engine == types.SqliteStoreEngine {
const sqliteSQL = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
`
return db.Exec(sqliteSQL, accountID, string(entity), target).Error
}
return db.Exec(sqlStr, accountID, string(entity), target).Error
case types.MysqlStoreEngine:
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
`
return db.Exec(sqlStr, accountID, string(entity), target).Error
default:
return fmt.Errorf("unsupported store engine for account_seq counter: %v", engine)
}
}
// transaction wraps a GORM transaction with MySQL-specific FK checks handling // transaction wraps a GORM transaction with MySQL-specific FK checks handling
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora // Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error { func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
@@ -3800,7 +4079,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
return status.Errorf(status.InvalidArgument, "group is nil") return status.Errorf(status.InvalidArgument, "group is nil")
} }
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil { if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err) log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store") return status.Errorf(status.Internal, "failed to save group to store")
} }
@@ -3888,7 +4167,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
// SavePolicy saves a policy to the database. // SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy) result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store") return status.Errorf(status.Internal, "failed to save policy to store")

View File

@@ -220,6 +220,11 @@ type Store interface {
GetStoreEngine() types.Engine GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
// AllocateAccountSeqID returns the next per-account integer id for the given
// component kind. Must run inside a transaction so the increment is serialized
// with the component insert.
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
SaveNetwork(ctx context.Context, network *networkTypes.Network) error SaveNetwork(ctx context.Context, network *networkTypes.Network) error
@@ -556,6 +561,30 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error { func(db *gorm.DB) error {
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique") return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
}, },
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[networkTypes.Network](ctx, db, types.AccountSeqEntityNetwork, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[posture.Checks](ctx, db, types.AccountSeqEntityPostureCheck, "id")
},
} }
} }

View File

@@ -774,6 +774,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
} }
// AllocateAccountSeqID mocks base method.
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
}
// ExecuteInTransaction mocks base method. // ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error { func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
ctx context.Context ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram updateAccountPeersDurationMs metric.Float64Histogram
updateAccountPeersCounter metric.Int64Counter updateAccountPeersCounter metric.Int64Counter
nmapCounter metric.Int64Counter
getPeerNetworkMapDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter peerMetaUpdateCount metric.Int64Counter
@@ -59,6 +60,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err return nil, err
} }
nmapCounter, err := meter.Int64Counter("management.network.map.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of network maps computed, labeled by resource and operation trigger"))
if err != nil {
return nil, err
}
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter",
metric.WithUnit("1"), metric.WithUnit("1"),
metric.WithDescription("Number of updates with new meta data from the peers")) metric.WithDescription("Number of updates with new meta data from the peers"))
@@ -93,6 +101,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
peerMetaUpdateCount: peerMetaUpdateCount, peerMetaUpdateCount: peerMetaUpdateCount,
peerStatusUpdateCounter: peerStatusUpdateCounter, peerStatusUpdateCounter: peerStatusUpdateCounter,
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs, peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
nmapCounter: nmapCounter,
}, nil }, nil
} }
@@ -145,6 +154,16 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
) )
} }
// CountNmapTriggered increments the counter for calculated network maps with resource and operation labels.
func (metrics *AccountManagerMetrics) CountNmapTriggered(resource, operation string) {
metrics.nmapCounter.Add(metrics.ctx, 1,
metric.WithAttributes(
attribute.String("resource", resource),
attribute.String("operation", operation),
),
)
}
// CountPeerMetUpdate counts the number of peer meta updates // CountPeerMetUpdate counts the number of peer meta updates
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)

View File

@@ -42,26 +42,8 @@ const (
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
// firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules.
firewallRuleMinNativeSSHVer = "0.60.0"
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
defaultSSHPortString = "22"
defaultSSHPortNumber = 22
) )
type supportedFeatures struct {
nativeSSH bool
portRanges bool
}
type LookupMap map[string]struct{}
// AccountMeta is a struct that contains a stripped down version of the Account object. // AccountMeta is a struct that contains a stripped down version of the Account object.
// It doesn't carry any peers, groups, policies, or routes, etc. Just some metadata (e.g. ID, created by, created at, etc). // It doesn't carry any peers, groups, policies, or routes, etc. Just some metadata (e.g. ID, created by, created at, etc).
type AccountMeta struct { type AccountMeta struct {
@@ -1037,7 +1019,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
default: default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
} }
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { } else if peerInDestinations && PolicyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs() authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
} }
@@ -1103,15 +1085,15 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr) rules = append(rules, &fr)
} else { } else {
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) rules = append(rules, ExpandPortsAndRanges(fr, rule, targetPeer)...)
} }
rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ rules = AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, FirewallRuleContext{
direction: direction, Direction: direction,
dirStr: strconv.Itoa(direction), DirStr: strconv.Itoa(direction),
protocolStr: string(protocol), ProtocolStr: string(protocol),
actionStr: string(rule.Action), ActionStr: string(rule.Action),
portsJoined: strings.Join(rule.Ports, ","), PortsJoined: strings.Join(rule.Ports, ","),
}) })
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
@@ -1119,28 +1101,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
} }
} }
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
return true
}
}
return false
}
func portsIncludesSSH(ports []string) bool {
for _, port := range ports {
if port == defaultSSHPortString || port == nativeSSHPortString {
return true
}
}
return false
}
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
// //
// Returns a list of peers from specified groups that pass specified posture checks // Returns a list of peers from specified groups that pass specified posture checks
@@ -1240,7 +1200,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
} }
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) rules := GenerateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
fwRules = append(fwRules, rules...) fwRules = append(fwRules, rules...)
} }
} }
@@ -1733,95 +1693,6 @@ func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target
} }
} }
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
var expanded []*FirewallRule
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
for _, portRange := range rule.PortRanges {
// prefer PolicyRule.Ports
if len(rule.Ports) > 0 {
break
}
fr := base
if features.portRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded)
}
return expanded
}
// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured.
func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule {
shouldAdd := false
for _, fr := range expanded {
if isPortInRule(nativeSSHPortString, 22022, fr) {
return expanded
}
if isPortInRule(defaultSSHPortString, 22, fr) {
shouldAdd = true
}
}
if !shouldAdd {
return expanded
}
fr := base
fr.Port = nativeSSHPortString
return append(expanded, &fr)
}
func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool {
return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End)
}
// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support.
// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled
// in both management and client, we indicate to add the native port.
func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool {
return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP
}
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
if strings.Contains(peerVer, "dev") {
return supportedFeatures{true, true}
}
var features supportedFeatures
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer)
features.nativeSSH = err == nil && meetMinVer
if features.nativeSSH {
features.portRanges = true
} else {
meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
features.portRanges = err == nil && meetMinVer
}
return features
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect. // filterZoneRecordsForPeers filters DNS records to only include peers to connect.
// AAAA records are excluded when the requesting peer lacks IPv6 capability. // AAAA records are excluded when the requesting peer lacks IPv6 capability.

View File

@@ -16,6 +16,49 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
// GetPeerNetworkMapResult dispatches to either the legacy-NetworkMap path or
// the components path based on the peer's capability and the kill switch.
// Capable peers (PeerCapabilityComponentNetworkMap) get the raw components
// shape — the server skips Calculate() entirely for them, saving CPU
// proportional to the number of capable peers in the account. Legacy peers
// (or any peer when componentsDisabled is true) get the fully-expanded
// NetworkMap as before.
func (a *Account) GetPeerNetworkMapResult(
ctx context.Context,
peerID string,
componentsDisabled bool,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) PeerNetworkMapResult {
peer := a.Peers[peerID]
if !componentsDisabled && peer != nil && peer.SupportsComponentNetworkMap() {
components := a.GetPeerNetworkMapComponents(
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs,
)
// Mirror legacy graceful-degrade: GetPeerNetworkMapFromComponents
// returns &NetworkMap{Network: a.Network.Copy()} when components is
// nil. Match that floor so the receiving client always sees the
// account Network identifier, not a fully-empty envelope.
if components == nil {
components = &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
}
}
return PeerNetworkMapResult{Components: components}
}
return PeerNetworkMapResult{
NetworkMap: a.GetPeerNetworkMapFromComponents(
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, metrics, groupIDToUserIDs,
),
}
}
func (a *Account) GetPeerNetworkMapFromComponents( func (a *Account) GetPeerNetworkMapFromComponents(
ctx context.Context, ctx context.Context,
peerID string, peerID string,
@@ -82,15 +125,27 @@ func (a *Account) GetPeerNetworkMapComponents(
} }
components := &NetworkMapComponents{ components := &NetworkMapComponents{
PeerID: peerID, PeerID: peerID,
Network: a.Network.Copy(), Network: a.Network.Copy(),
NameServerGroups: make([]*nbdns.NameServerGroup, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain, CustomZoneDomain: peersCustomZone.Domain,
ResourcePoliciesMap: make(map[string][]*Policy), ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0), NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)), PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer), RouterPeers: make(map[string]*nbpeer.Peer),
NetworkXIDToSeq: make(map[string]uint32, len(a.Networks)),
PostureCheckXIDToSeq: make(map[string]uint32, len(a.PostureChecks)),
}
for _, n := range a.Networks {
if n != nil && n.HasSeqID() {
components.NetworkXIDToSeq[n.ID] = n.AccountSeqID
}
}
for _, pc := range a.PostureChecks {
if pc != nil && pc.HasSeqID() {
components.PostureCheckXIDToSeq[pc.ID] = pc.AccountSeqID
}
} }
components.AccountSettings = &AccountSettingsInfo{ components.AccountSettings = &AccountSettingsInfo{
@@ -209,21 +264,26 @@ func (a *Account) GetPeerNetworkMapComponents(
components.ResourcePoliciesMap[resource.ID] = policies components.ResourcePoliciesMap[resource.ID] = policies
} }
components.RoutersMap[resource.NetworkID] = networkRoutingPeers // Only expose router peers and the per-network routers_map when this
for peerIDKey := range networkRoutingPeers { // target peer actually has access to the resource (either as a router
if p := a.Peers[peerIDKey]; p != nil { // itself or via a policy that includes it as a source). Without this
if _, exists := components.RouterPeers[peerIDKey]; !exists { // gate, every peer's envelope was leaking router peers of every
components.RouterPeers[peerIDKey] = p // network in the account — accounts with many tenants/networks
} // shipped tens of unrelated peers in `peers[]` and `routers_map`.
if _, exists := components.Peers[peerIDKey]; !exists { if addSourcePeers {
if _, validated := validatedPeersMap[peerIDKey]; validated { components.RoutersMap[resource.NetworkID] = networkRoutingPeers
components.Peers[peerIDKey] = p for peerIDKey := range networkRoutingPeers {
if p := a.Peers[peerIDKey]; p != nil {
if _, exists := components.RouterPeers[peerIDKey]; !exists {
components.RouterPeers[peerIDKey] = p
}
if _, exists := components.Peers[peerIDKey]; !exists {
if _, validated := validatedPeersMap[peerIDKey]; validated {
components.Peers[peerIDKey] = p
}
} }
} }
} }
}
if addSourcePeers {
components.NetworkResources = append(components.NetworkResources, resource) components.NetworkResources = append(components.NetworkResources, resource)
} }
} }
@@ -254,18 +314,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
relevantPeerIDs[peerID] = a.GetPeer(peerID) relevantPeerIDs[peerID] = a.GetPeer(peerID)
peerGroupSet := make(map[string]struct{}, 8)
for groupID, group := range a.Groups { for groupID, group := range a.Groups {
if slices.Contains(group.Peers, peerID) { if slices.Contains(group.Peers, peerID) {
relevantGroupIDs[groupID] = a.GetGroup(groupID) relevantGroupIDs[groupID] = a.GetGroup(groupID)
peerGroupSet[groupID] = struct{}{}
} }
} }
routeAccessControlGroups := make(map[string]struct{}) routeAccessControlGroups := make(map[string]struct{})
for _, r := range a.Routes { for _, r := range a.Routes {
for _, groupID := range r.Groups { if r == nil {
continue
}
relevant := r.Peer == peerID
if !relevant {
for _, groupID := range r.PeerGroups {
if _, ok := peerGroupSet[groupID]; ok {
relevant = true
break
}
}
}
if !relevant && r.Enabled {
for _, groupID := range r.Groups {
if _, ok := peerGroupSet[groupID]; ok {
relevant = true
break
}
}
}
if !relevant {
continue
}
for _, groupID := range r.PeerGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID) relevantGroupIDs[groupID] = a.GetGroup(groupID)
} }
for _, groupID := range r.PeerGroups { for _, groupID := range r.Groups {
relevantGroupIDs[groupID] = a.GetGroup(groupID) relevantGroupIDs[groupID] = a.GetGroup(groupID)
} }
if r.Enabled { if r.Enabled {
@@ -274,6 +360,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
routeAccessControlGroups[groupID] = struct{}{} routeAccessControlGroups[groupID] = struct{}{}
} }
} }
// Include route advertisers in relevantPeerIDs. The envelope
// encoder writes route.peer_index by looking up r.Peer in the
// shipped peers list; if the advertiser is policy-isolated from
// the target peer (no rule edge between them), it would otherwise
// be omitted and the decoder would fail to resolve r.Peer, leaving
// the client without a WG tunnel target for this route. Legacy
// NetworkMap.Routes shipped the WG public key inline, so the
// equivalence path doesn't surface this — but the dependency is
// real once a client actually tries to use the route.
// Gate by validatedPeersMap so non-validated advertisers stay out
// (matches the network-resource router behaviour at the bottom of
// this loop, and the legacy invariant that only validated peers
// reach a client's view).
if r.Peer != "" {
if _, ok := validatedPeersMap[r.Peer]; ok {
if p := a.GetPeer(r.Peer); p != nil {
relevantPeerIDs[r.Peer] = p
}
}
}
for _, groupID := range r.PeerGroups {
g := a.GetGroup(groupID)
if g == nil {
continue
}
for _, pid := range g.Peers {
if _, exists := relevantPeerIDs[pid]; exists {
continue
}
if _, ok := validatedPeersMap[pid]; !ok {
continue
}
if p := a.GetPeer(pid); p != nil {
relevantPeerIDs[pid] = p
}
}
}
relevantRoutes = append(relevantRoutes, r) relevantRoutes = append(relevantRoutes, r)
} }
@@ -353,7 +477,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
default: default:
sshReqs.needAllowedUserIDs = true sshReqs.needAllowedUserIDs = true
} }
} else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled { } else if PolicyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
sshReqs.needAllowedUserIDs = true sshReqs.needAllowedUserIDs = true
} }
} }
@@ -486,6 +610,13 @@ func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChe
return dest return dest
} }
// filterGroupPeers trims each group's Peers slice to only those peers that
// also appear in `peers`. Groups whose filtered list is empty are NOT
// deleted from the map — they're kept so the components wire encoder can
// still resolve seq references from routes/policies/access-control groups
// that name them. Calculate() tolerates groups with empty Peers (the inner
// loops simply iterate zero times), so retaining them is behaviourally a
// no-op for the legacy path that consumes the same NetworkMapComponents.
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) { func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
for groupID, groupInfo := range *groups { for groupID, groupInfo := range *groups {
filteredPeers := make([]string, 0, len(groupInfo.Peers)) filteredPeers := make([]string, 0, len(groupInfo.Peers))
@@ -495,9 +626,7 @@ func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer)
} }
} }
if len(filteredPeers) == 0 { if len(filteredPeers) != len(groupInfo.Peers) {
delete(*groups, groupID)
} else if len(filteredPeers) != len(groupInfo.Peers) {
ng := groupInfo.Copy() ng := groupInfo.Copy()
ng.Peers = filteredPeers ng.Peers = filteredPeers
(*groups)[groupID] = ng (*groups)[groupID] = ng

View File

@@ -0,0 +1,29 @@
package types
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
type AccountSeqEntity string
const (
AccountSeqEntityPolicy AccountSeqEntity = "policy"
AccountSeqEntityGroup AccountSeqEntity = "group"
AccountSeqEntityRoute AccountSeqEntity = "route"
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
AccountSeqEntityNetwork AccountSeqEntity = "network"
AccountSeqEntityPostureCheck AccountSeqEntity = "posture_check"
)
// AccountSeqCounter tracks the next per-account integer id for a given component
// kind. Reads/writes go through the store inside the same transaction as the
// component insert so two concurrent inserts cannot collide on the same id.
type AccountSeqCounter struct {
AccountID string `gorm:"primaryKey;size:255"`
Entity string `gorm:"primaryKey;size:32"`
NextID uint32 `gorm:"not null;default:1"`
}
// TableName overrides the GORM-derived table name.
func (AccountSeqCounter) TableName() string {
return "account_seq_counters"
}

View File

@@ -700,7 +700,7 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := expandPortsAndRanges(tt.base, tt.rule, tt.peer) result := ExpandPortsAndRanges(tt.base, tt.rule, tt.peer)
var ports []string var ports []string
for _, fr := range result { for _, fr := range result {

View File

@@ -0,0 +1,142 @@
package types
import (
"context"
"math/rand"
"net"
"net/netip"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
nbroute "github.com/netbirdio/netbird/route"
sharedtypes "github.com/netbirdio/netbird/shared/management/types"
)
// Type aliases for types relocated to shared/management/types so that the
// client-side compute path can depend on them
type DNSSettings = sharedtypes.DNSSettings
type FirewallRule = sharedtypes.FirewallRule
type Group = sharedtypes.Group
type GroupPeer = sharedtypes.GroupPeer
type Network = sharedtypes.Network
type NetworkMap = sharedtypes.NetworkMap
type ForwardingRule = sharedtypes.ForwardingRule
type Policy = sharedtypes.Policy
type PolicyUpdateOperation = sharedtypes.PolicyUpdateOperation
type PolicyRule = sharedtypes.PolicyRule
type PolicyUpdateOperationType = sharedtypes.PolicyUpdateOperationType
type PolicyTrafficActionType = sharedtypes.PolicyTrafficActionType
type PolicyRuleProtocolType = sharedtypes.PolicyRuleProtocolType
type PolicyRuleDirection = sharedtypes.PolicyRuleDirection
type RulePortRange = sharedtypes.RulePortRange
type Resource = sharedtypes.Resource
type ResourceType = sharedtypes.ResourceType
type RouteFirewallRule = sharedtypes.RouteFirewallRule
type NetworkMapComponents = sharedtypes.NetworkMapComponents
type AccountSettingsInfo = sharedtypes.AccountSettingsInfo
type GroupCompact = sharedtypes.GroupCompact
type NetworkMapComponentsCompact = sharedtypes.NetworkMapComponentsCompact
type LookupMap = sharedtypes.LookupMap
type FirewallRuleContext = sharedtypes.FirewallRuleContext
const (
GroupIssuedAPI = sharedtypes.GroupIssuedAPI
GroupIssuedJWT = sharedtypes.GroupIssuedJWT
GroupIssuedIntegration = sharedtypes.GroupIssuedIntegration
GroupAllName = sharedtypes.GroupAllName
)
// Function forwarders preserve types.X(...) call sites that previously
// resolved to package-local funcs. Plain forwarders (not var aliases) keep
// the symbol immutable and allow the inliner to flatten the call.
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return sharedtypes.PolicyRuleImpliesLegacySSH(rule)
}
func ExpandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
return sharedtypes.ExpandPortsAndRanges(base, rule, peer)
}
func AppendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc FirewallRuleContext) []*FirewallRule {
return sharedtypes.AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, rc)
}
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
return sharedtypes.CalculateNetworkMapFromComponents(ctx, components)
}
func GenerateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule {
return sharedtypes.GenerateRouteFirewallRules(ctx, route, rule, groupPeers, direction, includeIPv6)
}
func AllocateIPv6Subnet(r *rand.Rand) net.IPNet {
return sharedtypes.AllocateIPv6Subnet(r)
}
func NewNetwork() *Network {
return sharedtypes.NewNetwork()
}
func AllocatePeerIP(prefix netip.Prefix, takenIps []netip.Addr) (netip.Addr, error) {
return sharedtypes.AllocatePeerIP(prefix, takenIps)
}
func AllocateRandomPeerIP(prefix netip.Prefix) (netip.Addr, error) {
return sharedtypes.AllocateRandomPeerIP(prefix)
}
func AllocateRandomPeerIPv6(prefix netip.Prefix) (netip.Addr, error) {
return sharedtypes.AllocateRandomPeerIPv6(prefix)
}
func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) {
return sharedtypes.ParseRuleString(rule)
}
const (
FirewallRuleDirectionIN = sharedtypes.FirewallRuleDirectionIN
FirewallRuleDirectionOUT = sharedtypes.FirewallRuleDirectionOUT
)
const (
ResourceTypePeer = sharedtypes.ResourceTypePeer
ResourceTypeDomain = sharedtypes.ResourceTypeDomain
ResourceTypeHost = sharedtypes.ResourceTypeHost
ResourceTypeSubnet = sharedtypes.ResourceTypeSubnet
)
const (
PolicyTrafficActionAccept = sharedtypes.PolicyTrafficActionAccept
PolicyTrafficActionDrop = sharedtypes.PolicyTrafficActionDrop
)
const (
PolicyRuleProtocolALL = sharedtypes.PolicyRuleProtocolALL
PolicyRuleProtocolTCP = sharedtypes.PolicyRuleProtocolTCP
PolicyRuleProtocolUDP = sharedtypes.PolicyRuleProtocolUDP
PolicyRuleProtocolICMP = sharedtypes.PolicyRuleProtocolICMP
PolicyRuleProtocolNetbirdSSH = sharedtypes.PolicyRuleProtocolNetbirdSSH
)
const (
PolicyRuleFlowDirect = sharedtypes.PolicyRuleFlowDirect
PolicyRuleFlowBidirect = sharedtypes.PolicyRuleFlowBidirect
)
const (
DefaultRuleName = sharedtypes.DefaultRuleName
DefaultRuleDescription = sharedtypes.DefaultRuleDescription
DefaultPolicyName = sharedtypes.DefaultPolicyName
DefaultPolicyDescription = sharedtypes.DefaultPolicyDescription
)

View File

@@ -0,0 +1,180 @@
package types_test
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"testing"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/types"
)
// wireBenchScales — trimmed scale set for wire-size measurements. Encoding
// and marshalling are linear, so the largest extremes don't add signal.
var wireBenchScales = []benchmarkScale{
{"100peers_5groups", 100, 5},
{"500peers_20groups", 500, 20},
{"1000peers_50groups", 1000, 50},
{"5000peers_100groups", 5000, 100},
}
// populateAccountSeqIDs assigns deterministic AccountSeqIDs to every group and
// policy in the account so that the component encoder can reference them. The
// scalableTestAccount fixture builds entities by struct literal and skips this
// step, but production paths populate the IDs via the store layer.
func populateAccountSeqIDs(account *types.Account) {
var nextGroupSeq uint32 = 1
for _, g := range account.Groups {
g.AccountSeqID = nextGroupSeq
nextGroupSeq++
}
var nextPolicySeq uint32 = 1
for _, p := range account.Policies {
p.AccountSeqID = nextPolicySeq
nextPolicySeq++
}
}
// assignValidWgKeys overwrites every peer's Key with a valid base64-encoded
// 32-byte string. The default scalableTestAccount uses unparsable strings
// like "key-peer-0", which makes the components encoder emit a nil WgPubKey
// and the legacy encoder ship 10-char placeholders — both shrink the wire
// size in unrealistic ways. Production peers always have valid 44-char base64
// keys, so any benchmark/breakdown that wants honest numbers must call this.
func assignValidWgKeys(account *types.Account) {
for _, p := range account.Peers {
var raw [32]byte
_, _ = rand.Read(raw[:])
p.Key = base64.StdEncoding.EncodeToString(raw[:])
}
}
// BenchmarkNetworkMapWireEncode reports per-call ns and the marshaled wire
// size for both encoding paths. Run with:
//
// go test -run=^$ -bench=BenchmarkNetworkMapWireEncode -benchmem ./management/server/types/
func BenchmarkNetworkMapWireEncode(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range wireBenchScales {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
// Pre-encode once so the size metric is identical for every run inside
// the same scale; the b.Loop call only re-runs encode + Marshal.
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
if err != nil {
b.Fatalf("marshal legacy networkmap: %v", err)
}
envelopeInput := mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
}
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
envelopeBytes, err := goproto.Marshal(envelope)
if err != nil {
b.Fatalf("marshal envelope: %v", err)
}
b.Run(fmt.Sprintf("legacy/%s", scale.name), func(b *testing.B) {
b.ReportAllocs()
b.ReportMetric(float64(len(legacyBytes)), "bytes/msg")
b.ResetTimer()
for range b.N {
resp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
if _, err := goproto.Marshal(resp.NetworkMap); err != nil {
b.Fatal(err)
}
}
})
b.Run(fmt.Sprintf("components/%s", scale.name), func(b *testing.B) {
b.ReportAllocs()
b.ReportMetric(float64(len(envelopeBytes)), "bytes/msg")
b.ResetTimer()
for range b.N {
env := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
if _, err := goproto.Marshal(env); err != nil {
b.Fatal(err)
}
}
})
}
}
// BenchmarkNetworkMapWireSize is a fast snapshot of the wire size by scale
// without a tight encode loop. Run with -bench to see one ns/op + bytes per
// scale (treat the timing as informational; the sample is one Marshal per
// scale, not the full b.N loop).
func BenchmarkNetworkMapWireSize(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range wireBenchScales {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
if err != nil {
b.Fatalf("marshal legacy networkmap: %v", err)
}
env := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
})
envBytes, err := goproto.Marshal(env)
if err != nil {
b.Fatalf("marshal envelope: %v", err)
}
b.Run(fmt.Sprintf("size/%s", scale.name), func(b *testing.B) {
b.ReportMetric(float64(len(legacyBytes)), "legacy_bytes")
b.ReportMetric(float64(len(envBytes)), "components_bytes")
ratio := float64(len(envBytes)) / float64(len(legacyBytes))
b.ReportMetric(ratio, "components/legacy")
for range b.N {
}
})
}
}

View File

@@ -0,0 +1,150 @@
package types_test
import (
"context"
"fmt"
"os"
"testing"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkMapWireBreakdown is a one-shot diagnostic: it computes the wire
// size attributable to each top-level field of both the legacy NetworkMap and
// the components NetworkMapEnvelope at the 5000-peer scale, so the migration
// docs can attribute the size reduction to each optimization. Runs only on
// demand via -run TestNetworkMapWireBreakdown.
func TestNetworkMapWireBreakdown(t *testing.T) {
if testing.Short() {
t.Skip("size diagnostic, skipped with -short")
}
if os.Getenv("NB_RUN_WIRE_BREAKDOWN") != "1" {
t.Skip("set NB_RUN_WIRE_BREAKDOWN=1 to run wire breakdown diagnostic")
}
const peerCount, groupCount = 5000, 100
account, validatedPeers := scalableTestAccount(peerCount, groupCount)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyTotal := mustMarshalSize(t, legacyResp.NetworkMap)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
})
componentsTotal := mustMarshalSize(t, envelope)
t.Logf("\n=== LEGACY NetworkMap (%d peers, %d groups) ===", peerCount, groupCount)
t.Logf(" Total: %d bytes\n", legacyTotal)
legacyBreakdown := []struct {
name string
nm *proto.NetworkMap
}{
{"RemotePeers", &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers}},
{"OfflinePeers", &proto.NetworkMap{OfflinePeers: legacyResp.NetworkMap.OfflinePeers}},
{"FirewallRules", &proto.NetworkMap{FirewallRules: legacyResp.NetworkMap.FirewallRules}},
{"Routes", &proto.NetworkMap{Routes: legacyResp.NetworkMap.Routes}},
{"RoutesFirewallRules", &proto.NetworkMap{RoutesFirewallRules: legacyResp.NetworkMap.RoutesFirewallRules}},
{"DNSConfig", &proto.NetworkMap{DNSConfig: legacyResp.NetworkMap.DNSConfig}},
{"PeerConfig", &proto.NetworkMap{PeerConfig: legacyResp.NetworkMap.PeerConfig}},
{"SshAuth", &proto.NetworkMap{SshAuth: legacyResp.NetworkMap.SshAuth}},
}
for _, e := range legacyBreakdown {
size := mustMarshalSize(t, e.nm)
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, legacyTotal))
}
full := envelope.GetFull()
if full == nil {
t.Fatalf("expected full network map envelope payload, got nil")
}
t.Logf("\n=== COMPONENTS NetworkMapEnvelope (%d peers, %d groups) ===", peerCount, groupCount)
t.Logf(" Total: %d bytes (%.1f%% of legacy)\n", componentsTotal, pct(componentsTotal, legacyTotal))
componentsBreakdown := []struct {
name string
nm *proto.NetworkMapComponentsFull
}{
{"Peers", &proto.NetworkMapComponentsFull{Peers: full.Peers}},
{"Policies", &proto.NetworkMapComponentsFull{Policies: full.Policies}},
{"Groups", &proto.NetworkMapComponentsFull{Groups: full.Groups}},
{"Routes (raw)", &proto.NetworkMapComponentsFull{Routes: full.Routes}},
{"NameServerGroups", &proto.NetworkMapComponentsFull{NameserverGroups: full.NameserverGroups}},
{"AllDNSRecords", &proto.NetworkMapComponentsFull{AllDnsRecords: full.AllDnsRecords}},
{"AccountZones", &proto.NetworkMapComponentsFull{AccountZones: full.AccountZones}},
{"NetworkResources", &proto.NetworkMapComponentsFull{NetworkResources: full.NetworkResources}},
{"RoutersMap", &proto.NetworkMapComponentsFull{RoutersMap: full.RoutersMap}},
{"ResourcePoliciesMap", &proto.NetworkMapComponentsFull{ResourcePoliciesMap: full.ResourcePoliciesMap}},
{"GroupIDToUserIDs", &proto.NetworkMapComponentsFull{GroupIdToUserIds: full.GroupIdToUserIds}},
{"AllowedUserIDs", &proto.NetworkMapComponentsFull{AllowedUserIds: full.AllowedUserIds}},
{"PostureFailedPeers", &proto.NetworkMapComponentsFull{PostureFailedPeers: full.PostureFailedPeers}},
{"DNSSettings", &proto.NetworkMapComponentsFull{DnsSettings: full.DnsSettings}},
{"PeerConfig", &proto.NetworkMapComponentsFull{PeerConfig: full.PeerConfig}},
{"AgentVersions", &proto.NetworkMapComponentsFull{AgentVersions: full.AgentVersions}},
}
for _, e := range componentsBreakdown {
size := mustMarshalSize(t, e.nm)
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, componentsTotal))
}
t.Logf("\n=== Per-PeerCompact average ===")
if len(full.Peers) > 0 {
t.Logf(" PeerCompact avg: %d bytes/peer", mustMarshalSize(t, &proto.NetworkMapComponentsFull{Peers: full.Peers})/len(full.Peers))
}
if len(legacyResp.NetworkMap.RemotePeers) > 0 {
t.Logf(" RemotePeer avg: %d bytes/peer",
mustMarshalSize(t, &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers})/len(legacyResp.NetworkMap.RemotePeers))
}
t.Logf("\n=== FirewallRule expansion footprint ===")
t.Logf(" legacy FirewallRules count: %d", len(legacyResp.NetworkMap.FirewallRules))
t.Logf(" components Policies count: %d", len(full.Policies))
t.Logf(" components Groups count: %d", len(full.Groups))
totalGroupPeerIdxs := 0
for _, g := range full.Groups {
totalGroupPeerIdxs += len(g.PeerIndexes)
}
t.Logf(" components peer-index refs across all groups: %d", totalGroupPeerIdxs)
}
func mustMarshalSize(t *testing.T, m goproto.Message) int {
b, err := goproto.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
return len(b)
}
func pct(part, total int) float64 {
if total == 0 {
return 0
}
return 100 * float64(part) / float64(total)
}
// Stops fmt being unused if the breakdown loop above is later commented out.
var _ = fmt.Sprintf

View File

@@ -0,0 +1,25 @@
package types
// PeerNetworkMapResult is what the network_map controller produces for a
// single peer. Exactly one of NetworkMap or Components is populated depending
// on the peer's capability:
//
// - Components-capable peers (PeerCapabilityComponentNetworkMap) get
// Components: the raw types.NetworkMapComponents the client decodes and
// runs Calculate() on locally. NetworkMap stays nil — the server skips
// the expansion entirely.
// - Legacy peers (or any peer when the kill switch is set) get NetworkMap:
// the fully-expanded view the legacy gRPC path consumes.
//
// The gRPC layer (ToSyncResponseForPeer) dispatches by which field is
// non-nil; callers must not rely on both being set.
type PeerNetworkMapResult struct {
NetworkMap *NetworkMap
Components *NetworkMapComponents
}
// IsComponents reports whether the result carries the components shape.
// Use this in preference to direct nil checks on the fields.
func (r PeerNetworkMapResult) IsComponents() bool {
return r.Components != nil
}

View File

@@ -0,0 +1,104 @@
package types_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// helper: marks the given peer as components-capable.
func markCapable(p *nbpeer.Peer) {
p.Meta.Capabilities = append(p.Meta.Capabilities, nbpeer.PeerCapabilityComponentNetworkMap)
}
func TestGetPeerNetworkMapResult_CapablePeerGetsComponents(t *testing.T) {
account, validatedPeers := scalableTestAccount(10, 2)
markCapable(account.Peers["peer-0"])
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
false, // componentsDisabled
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
require.True(t, result.IsComponents(), "capable peer must get the components shape")
assert.Nil(t, result.NetworkMap)
require.NotNil(t, result.Components)
assert.Equal(t, "peer-0", result.Components.PeerID)
}
func TestGetPeerNetworkMapResult_LegacyPeerGetsNetworkMap(t *testing.T) {
account, validatedPeers := scalableTestAccount(10, 2)
// peer-0 left without the component capability
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
false,
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
assert.False(t, result.IsComponents())
assert.Nil(t, result.Components)
require.NotNil(t, result.NetworkMap, "legacy peer must get a NetworkMap")
}
func TestGetPeerNetworkMapResult_KillSwitchOverridesCapability(t *testing.T) {
// Capable peer + componentsDisabled=true → falls back to legacy.
account, validatedPeers := scalableTestAccount(10, 2)
markCapable(account.Peers["peer-0"])
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
true, // componentsDisabled = true (kill switch)
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
assert.False(t, result.IsComponents(), "kill switch must force legacy NetworkMap path")
assert.Nil(t, result.Components)
require.NotNil(t, result.NetworkMap)
}
func TestPeerNetworkMapResult_IsComponents(t *testing.T) {
assert.True(t, types.PeerNetworkMapResult{Components: &types.NetworkMapComponents{}}.IsComponents())
assert.False(t, types.PeerNetworkMapResult{NetworkMap: &types.NetworkMap{}}.IsComponents())
assert.False(t, types.PeerNetworkMapResult{}.IsComponents())
}

View File

@@ -95,6 +95,9 @@ type Route struct {
ID ID `gorm:"primaryKey"` ID ID `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"`
// Network and Domains are mutually exclusive // Network and Domains are mutually exclusive
Network netip.Prefix `gorm:"serializer:json"` Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"` Domains domain.List `gorm:"serializer:json"`
@@ -128,6 +131,7 @@ func (r *Route) Copy() *Route {
route := &Route{ route := &Route{
ID: r.ID, ID: r.ID,
AccountID: r.AccountID, AccountID: r.AccountID,
AccountSeqID: r.AccountSeqID,
Description: r.Description, Description: r.Description,
NetID: r.NetID, NetID: r.NetID,
Network: r.Network, Network: r.Network,

View File

@@ -16,6 +16,10 @@ type Client interface {
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
// ExtendAuthSession refreshes the peer's SSO session deadline using a fresh JWT.
// Returns the new absolute deadline; zero time when the server reports the peer
// is not eligible for session extension.
ExtendAuthSession(sysInfo *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error)
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)

View File

@@ -316,27 +316,74 @@ func TestClient_Sync(t *testing.T) {
select { select {
case resp := <-ch: case resp := <-ch:
if resp.GetPeerConfig() == nil { if resp.GetPeerConfig() == nil && resp.GetNetworkMap().GetPeerConfig() == nil {
t.Error("expecting non nil PeerConfig got nil") t.Error("expecting non nil PeerConfig got nil")
} }
if resp.GetNetbirdConfig() == nil { if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil") t.Error("expecting non nil NetbirdConfig got nil")
} }
if len(resp.GetRemotePeers()) != 1 { // Component-capable clients receive a NetworkMapEnvelope; the
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers())) // remote-peers list is encoded inside it. Decode it and check the
// envelope's peers slice. Legacy peers populate the top-level
// RemotePeers; both shapes must surface exactly one remote peer.
remotePeerKeys := remotePeerKeysFromSync(resp, testKey.PublicKey().String())
if len(remotePeerKeys) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(remotePeerKeys))
return return
} }
if resp.GetRemotePeersIsEmpty() == true { if resp.GetNetworkMap() != nil && resp.GetRemotePeersIsEmpty() {
t.Error("expecting RemotePeers property to be false, got true") t.Error("expecting RemotePeers property to be false, got true")
} }
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() { if remotePeerKeys[0] != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey()) t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), remotePeerKeys[0])
} }
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish") t.Error("timeout waiting for test to finish")
} }
} }
// remotePeerKeysFromSync extracts the remote-peer WG keys from either the
// legacy NetworkMap.RemotePeers list or the components NetworkMapEnvelope's
// inner peers slice (filtering out the local receiving peer identified by
// localKey, since the envelope's peers list is index-addressed and includes
// the local peer alongside remotes).
func remotePeerKeysFromSync(resp *mgmtProto.SyncResponse, localKey string) []string {
if rp := resp.GetRemotePeers(); len(rp) > 0 {
out := make([]string, 0, len(rp))
for _, p := range rp {
out = append(out, p.GetWgPubKey())
}
return out
}
env := resp.GetNetworkMapEnvelope().GetFull()
if env == nil {
return nil
}
out := make([]string, 0, len(env.GetPeers()))
for _, p := range env.GetPeers() {
key := wgKeyFromBytes(p.GetWgPubKey())
if key == "" || key == localKey {
continue
}
out = append(out, key)
}
return out
}
// wgKeyFromBytes mirrors the client-side decoder: the envelope ships raw 32
// bytes; reconstruct the standard base64 key the test compares against.
func wgKeyFromBytes(raw []byte) string {
if len(raw) == 0 {
return ""
}
var k wgtypes.Key
if len(raw) != len(k) {
return ""
}
copy(k[:], raw)
return k.String()
}
func Test_SystemMetaDataFromClient(t *testing.T) { func Test_SystemMetaDataFromClient(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t) s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop() defer s.GracefulStop()

View File

@@ -607,6 +607,61 @@ func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels dom
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
} }
// ExtendAuthSession refreshes the peer's SSO session deadline on the management
// server using a freshly issued JWT. The tunnel is untouched: no network map
// sync, no peer reconnect. Returns the new absolute UTC deadline (zero time
// when the server reports the field empty).
func (c *GrpcClient) ExtendAuthSession(sysInfo *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
reqBody, err := encryption.EncryptMessage(*serverKey, c.key, &proto.ExtendAuthSessionRequest{
JwtToken: jwtToken,
Meta: infoToMetaData(sysInfo),
})
if err != nil {
log.Errorf("failed to encrypt extend auth session message: %s", err)
return nil, err
}
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.ExtendAuthSession(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: reqBody,
})
if err != nil {
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
if err := backoff.Retry(operation, nbgrpc.Backoff(c.ctx)); err != nil {
log.Errorf("failed to extend auth session on Management Service: %v", err)
return nil, err
}
out := &proto.ExtendAuthSessionResponse{}
if err := encryption.DecryptMessage(*serverKey, c.key, resp.Body, out); err != nil {
log.Errorf("failed to decrypt extend auth session response: %s", err)
return nil, err
}
return out, nil
}
// GetDeviceAuthorizationFlow returns a device authorization flow information. // GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages. // It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
@@ -950,6 +1005,10 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
func peerCapabilities(info system.Info) []proto.PeerCapability { func peerCapabilities(info system.Info) []proto.PeerCapability {
caps := []proto.PeerCapability{ caps := []proto.PeerCapability{
proto.PeerCapability_PeerCapabilitySourcePrefixes, proto.PeerCapability_PeerCapabilitySourcePrefixes,
// PeerCapabilityComponentNetworkMap signals that this client can
// decode the components-format SyncResponse.NetworkMapEnvelope and
// run Calculate() locally.
proto.PeerCapability_PeerCapabilityComponentNetworkMap,
} }
if !info.DisableIPv6 { if !info.DisableIPv6 {
caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay) caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay)

View File

@@ -14,6 +14,7 @@ type MockClient struct {
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
ExtendAuthSessionFunc func(info *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error)
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error) GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
GetServerURLFunc func() string GetServerURLFunc func() string
@@ -65,6 +66,13 @@ func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.Li
return m.LoginFunc(info, sshKey, dnsLabels) return m.LoginFunc(info, sshKey, dnsLabels)
} }
func (m *MockClient) ExtendAuthSession(info *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error) {
if m.ExtendAuthSessionFunc == nil {
return nil, nil
}
return m.ExtendAuthSessionFunc(info, jwtToken)
}
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) { func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil { if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil return nil, nil

View File

@@ -0,0 +1,610 @@
package networkmap
import (
"encoding/base64"
"fmt"
"net"
"net/netip"
"strconv"
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/types"
)
// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents
// the client can run Calculate() over. Every ID-reference on the wire is a
// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder
// synthesises consistent string IDs from the uint32s so the reconstructed
// components struct round-trips through Calculate exactly the way the
// server-side typed components would.
//
// ID scheme on the client side:
//
// Peers base64(wg_pub_key) // stable across snapshots
// Groups "g_<account_seq_id>"
// Policies "pol_<account_seq_id>" // 1 rule per policy
// Routes "r_<account_seq_id>"
// Network resources "nres_<account_seq_id>"
// Posture checks "pc_<account_seq_id>"
// Networks "net_<account_seq_id>"
// Nameserver groups "nsg_<account_seq_id>"
func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) {
if env == nil {
return nil, fmt.Errorf("nil envelope")
}
full := env.GetFull()
if full == nil {
return nil, fmt.Errorf("envelope has no Full payload")
}
c := &types.NetworkMapComponents{
PeerID: "", // engine fills its own peer id from PeerConfig
Network: decodeAccountNetwork(full.Network),
AccountSettings: decodeAccountSettings(full.AccountSettings),
CustomZoneDomain: full.CustomZoneDomain,
Peers: make(map[string]*nbpeer.Peer, len(full.Peers)),
Groups: make(map[string]*types.Group, len(full.Groups)),
Policies: make([]*types.Policy, 0, len(full.Policies)),
Routes: make([]*nbroute.Route, 0, len(full.Routes)),
NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)),
AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords),
AccountZones: decodeCustomZones(full.AccountZones),
ResourcePoliciesMap: make(map[string][]*types.Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)),
RouterPeers: make(map[string]*nbpeer.Peer),
AllowedUserIDs: stringSliceToSet(full.AllowedUserIds),
PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)),
GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)),
}
if full.DnsSettings != nil {
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds),
}
} else {
c.DNSSettings = &types.DNSSettings{}
}
// Phase 1: peers. The envelope's peers slice is index-addressed on the
// wire; we re-key by the peer's WireGuard public key (base64) so the
// in-memory components struct uses a stable identifier across
// snapshots. peerIDByIndex lets downstream phases resolve wire indexes
// back to that key. A peer with a missing or malformed wg_pub_key is
// skipped (and its index keeps "" so any cross-reference falls into the
// same missing-peer branch downstream) — matches legacy behaviour, which
// degrades gracefully rather than aborting the whole sync on a single
// bad row.
peerIDByIndex := make([]string, len(full.Peers))
for idx, pc := range full.Peers {
if pc == nil {
log.Warnf("envelope: peers[%d] is nil, skipping", idx)
continue
}
if len(pc.WgPubKey) != 32 {
log.Warnf("envelope: peers[%d] wg_pub_key length %d (want 32), skipping", idx, len(pc.WgPubKey))
continue
}
peerID := base64.StdEncoding.EncodeToString(pc.WgPubKey)
peer := decodePeerCompact(pc, peerID, full.AgentVersions)
c.Peers[peerID] = peer
peerIDByIndex[idx] = peerID
}
// Phase 2: groups. AccountSeqID becomes both the synthesized string ID
// and the GroupCompact.id wire value.
for i, gc := range full.Groups {
if gc == nil {
return nil, fmt.Errorf("invalid envelope: groups[%d] is nil", i)
}
groupID := synthGroupID(gc.Id)
peerIDs := make([]string, 0, len(gc.PeerIndexes))
for _, idx := range gc.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerIDs = append(peerIDs, peerIDByIndex[idx])
}
}
c.Groups[groupID] = &types.Group{
ID: groupID,
AccountSeqID: gc.Id,
Name: gc.Name,
Peers: peerIDs,
}
}
// Phase 3: policies (PolicyCompact = one rule per entry; current data
// model is 1 rule per policy). Policy.ID is synthesized from the
// per-account seq id; proto.FirewallRule.PolicyID downstream carries
// the same synth string (no xid on the wire).
for i, pc := range full.Policies {
if pc == nil {
return nil, fmt.Errorf("invalid envelope: policies[%d] is nil", i)
}
policyID := synthPolicyID(pc.Id)
c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex))
}
// Phase 4: routes.
for i, rr := range full.Routes {
if rr == nil {
return nil, fmt.Errorf("invalid envelope: routes[%d] is nil", i)
}
c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex))
}
// Phase 5: NSGs.
for i, nsg := range full.NameserverGroups {
if nsg == nil {
return nil, fmt.Errorf("invalid envelope: nameserver_groups[%d] is nil", i)
}
c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg))
}
// Phase 6: network resources.
for i, nr := range full.NetworkResources {
if nr == nil {
return nil, fmt.Errorf("invalid envelope: network_resources[%d] is nil", i)
}
c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr))
}
// Phase 7: routers_map (outer key = network seq id, inner key = peer-id
// reconstructed from peer_index). Synthesized network id is "net_<seq>".
for networkSeq, list := range full.RoutersMap {
networkID := synthNetworkID(networkSeq)
inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries))
for _, entry := range list.Entries {
if !entry.PeerIndexSet {
continue
}
if int(entry.PeerIndex) >= len(peerIDByIndex) {
continue
}
peerID := peerIDByIndex[entry.PeerIndex]
inner[peerID] = &routerTypes.NetworkRouter{
ID: "",
NetworkID: networkID,
AccountSeqID: entry.Id,
Peer: peerID,
PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds),
Masquerade: entry.Masquerade,
Metric: int(entry.Metric),
Enabled: entry.Enabled,
}
}
if len(inner) > 0 {
c.RoutersMap[networkID] = inner
}
}
// Phase 8: resource_policies_map (resource seq id → list of *types.Policy
// pointers from the decoded policies slice). Resource ID is synthesized
// the same way as in decodeNetworkResource.
for resourceSeq, idxs := range full.ResourcePoliciesMap {
if len(idxs.Indexes) == 0 {
continue
}
resourceID := synthNetworkResourceID(resourceSeq)
policies := make([]*types.Policy, 0, len(idxs.Indexes))
for _, i := range idxs.Indexes {
if int(i) < len(c.Policies) {
policies = append(policies, c.Policies[i])
}
}
if len(policies) > 0 {
c.ResourcePoliciesMap[resourceID] = policies
}
}
// Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings.
for groupSeq, list := range full.GroupIdToUserIds {
c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...)
}
// Phase 10: posture_failed_peers — wire keys are posture-check seq ids,
// values are peer indexes that need to be turned into peer ids. PolicyRule
// SourcePostureChecks (also synth ids) reference the same key space.
for checkSeq, set := range full.PostureFailedPeers {
checkID := synthPostureCheckID(checkSeq)
failed := make(map[string]struct{}, len(set.PeerIndexes))
for _, idx := range set.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
failed[peerIDByIndex[idx]] = struct{}{}
}
}
if len(failed) > 0 {
c.PostureFailedPeers[checkID] = failed
}
}
// Phase 11: router_peer_indexes — peers that act as routers. They're
// already in c.Peers (router peers are appended to the global peers
// list by the encoder); RouterPeers is the subset.
for _, idx := range full.RouterPeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerID := peerIDByIndex[idx]
c.RouterPeers[peerID] = c.Peers[peerID]
}
}
return c, nil
}
func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network {
if an == nil {
return nil
}
n := &types.Network{
Identifier: an.Identifier,
Dns: an.Dns,
Serial: an.Serial,
}
if an.NetCidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil {
n.Net = *ipnet
}
}
if an.NetV6Cidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil {
n.NetV6 = *ipnet
}
}
return n
}
func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo {
if as == nil {
return &types.AccountSettingsInfo{}
}
return &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs),
}
}
func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer {
var caps []int32
if pc.SupportsSourcePrefixes {
caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes)
}
if pc.SupportsIpv6 {
caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay)
}
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SSHKey: string(pc.SshPubKey),
SSHEnabled: pc.SshEnabled,
DNSLabel: pc.DnsLabel,
LoginExpirationEnabled: pc.LoginExpirationEnabled,
Meta: nbpeer.PeerSystemMeta{
WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx),
Capabilities: caps,
Flags: nbpeer.Flags{
ServerSSHAllowed: pc.ServerSshAllowed,
},
},
}
if pc.AddedWithSsoLogin {
// Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true.
// The original UserID isn't on the wire; the value is intentionally
// visibly synthetic so any future consumer that mistakes UserID for a
// real account user xid won't silently match (or worse, write the
// sentinel into a downstream record).
peer.UserID = "<env-sso>"
}
if pc.LastLoginUnixNano != 0 {
t := time.Unix(0, pc.LastLoginUnixNano)
peer.LastLogin = &t
}
switch len(pc.Ip) {
case 4:
peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]})
case 16:
var a [16]byte
copy(a[:], pc.Ip)
peer.IP = netip.AddrFrom16(a)
}
if len(pc.Ipv6) == 16 {
var a [16]byte
copy(a[:], pc.Ipv6)
peer.IPv6 = netip.AddrFrom16(a)
}
return peer
}
func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy {
rule := &types.PolicyRule{
ID: policyID, // 1 rule per policy → reuse synthesized id
PolicyID: policyID,
Enabled: true,
Action: actionFromProto(pc.Action),
Protocol: protocolFromProto(pc.Protocol),
Bidirectional: pc.Bidirectional,
Ports: uint32SliceToStrings(pc.Ports),
PortRanges: portRangesFromProto(pc.PortRanges),
Sources: groupIDsFromSeqs(pc.SourceGroupIds),
Destinations: groupIDsFromSeqs(pc.DestinationGroupIds),
AuthorizedUser: pc.AuthorizedUser,
AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups),
SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex),
DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex),
}
return &types.Policy{
ID: policyID,
AccountSeqID: pc.Id,
Enabled: true,
Rules: []*types.PolicyRule{rule},
SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds),
}
}
// resourceFromProto rebuilds types.Resource. For peer-typed resources the
// peer reference is reconstructed from the envelope's peer index — wire
// format ships no xid for peers, so we use the synthesized peer id.
func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource {
if r == nil {
return types.Resource{}
}
out := types.Resource{Type: types.ResourceType(r.Type)}
if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) {
out.ID = peerIDByIndex[r.PeerIndex]
}
return out
}
// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids.
// Mirrors groupIDsFromSeqs.
func postureCheckIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthPostureCheckID(s)
}
return out
}
// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form
// keys by group account_seq_id, the typed PolicyRule field keys by group
// xid string. We rebuild using the same synthetic scheme the rest of the
// decoder uses ("g<seq>").
func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string {
if len(m) == 0 {
return nil
}
out := make(map[string][]string, len(m))
for seq, list := range m {
if list == nil {
continue
}
out[synthGroupID(seq)] = append([]string(nil), list.Names...)
}
return out
}
func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route {
r := &nbroute.Route{
ID: nbroute.ID(synthRouteID(rr.Id)),
AccountSeqID: rr.Id,
NetID: nbroute.NetID(rr.NetId),
Description: rr.Description,
Domains: domainsFromPunycode(rr.Domains),
KeepRoute: rr.KeepRoute,
NetworkType: nbroute.NetworkType(rr.NetworkType),
Masquerade: rr.Masquerade,
Metric: int(rr.Metric),
Enabled: rr.Enabled,
Groups: groupIDsFromSeqs(rr.GroupIds),
AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds),
PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds),
SkipAutoApply: rr.SkipAutoApply,
}
if rr.NetworkCidr != "" {
if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil {
r.Network = p
}
}
if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) {
r.Peer = peerIDByIndex[rr.PeerIndex]
}
return r
}
func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup {
out := &nbdns.NameServerGroup{
ID: synthNameServerGroupID(nsg.Id),
AccountSeqID: nsg.Id,
Name: nsg.Name,
Description: nsg.Description,
Groups: groupIDsFromSeqs(nsg.GroupIds),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)),
}
for _, ns := range nsg.Nameservers {
if addr, err := netip.ParseAddr(ns.IP); err == nil {
out.NameServers = append(out.NameServers, nbdns.NameServer{
IP: addr,
NSType: nbdns.NameServerType(ns.NSType),
Port: int(ns.Port),
})
}
}
return out
}
func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource {
out := &resourceTypes.NetworkResource{
ID: synthNetworkResourceID(nr.Id),
AccountSeqID: nr.Id,
NetworkID: synthNetworkID(nr.NetworkSeq),
Name: nr.Name,
Description: nr.Description,
Type: resourceTypes.NetworkResourceType(nr.Type),
Address: nr.Address,
Domain: nr.DomainValue,
Enabled: nr.Enabled,
}
if nr.PrefixCidr != "" {
if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil {
out.Prefix = p
}
}
return out
}
func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord {
out := make([]nbdns.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, nbdns.SimpleRecord{
Name: r.Name,
Type: int(r.Type),
Class: r.Class,
TTL: int(r.TTL),
RData: r.RData,
})
}
return out
}
func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone {
out := make([]nbdns.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, nbdns.CustomZone{
Domain: z.Domain,
Records: decodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
// Synthetic ID generators — deterministic given the same wire input.
// Underscore-separated ("p_<n>", "pol_<n>", ...) so they're visually
// distinct in operator logs. fmt.Sprintf would dominate the decode hot path
// on large accounts (a 10k-peer envelope produces ~50k synth calls); the
// strconv.AppendUint builder keeps it allocation-light.
func synthID(prefix string, n uint32) string {
buf := make([]byte, 0, len(prefix)+10)
buf = append(buf, prefix...)
buf = strconv.AppendUint(buf, uint64(n), 10)
return string(buf)
}
func synthGroupID(seq uint32) string { return synthID("g_", seq) }
func synthPolicyID(seq uint32) string { return synthID("pol_", seq) }
func synthRouteID(seq uint32) string { return synthID("r_", seq) }
func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) }
func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) }
func synthNetworkID(seq uint32) string { return synthID("net_", seq) }
func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) }
func groupIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthGroupID(s)
}
return out
}
func uint32SliceToStrings(ports []uint32) []string {
if len(ports) == 0 {
return nil
}
out := make([]string, len(ports))
for i, p := range ports {
out[i] = strconv.FormatUint(uint64(p), 10)
}
return out
}
func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange {
if len(ranges) == 0 {
return nil
}
out := make([]types.RulePortRange, 0, len(ranges))
for _, r := range ranges {
if r == nil || r.Start > 65535 || r.End > 65535 {
continue
}
out = append(out, types.RulePortRange{
Start: uint16(r.Start),
End: uint16(r.End),
})
}
return out
}
func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType {
if a == proto.RuleAction_DROP {
return types.PolicyTrafficActionDrop
}
return types.PolicyTrafficActionAccept
}
func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType {
switch p {
case proto.RuleProtocol_TCP:
return types.PolicyRuleProtocolTCP
case proto.RuleProtocol_UDP:
return types.PolicyRuleProtocolUDP
case proto.RuleProtocol_ICMP:
return types.PolicyRuleProtocolICMP
case proto.RuleProtocol_ALL:
return types.PolicyRuleProtocolALL
case proto.RuleProtocol_NETBIRD_SSH:
return types.PolicyRuleProtocolNetbirdSSH
default:
return types.PolicyRuleProtocolALL
}
}
func lookupAgentVersion(table []string, idx uint32) string {
if int(idx) < len(table) {
return table[idx]
}
return ""
}
func stringSliceToSet(s []string) map[string]struct{} {
if len(s) == 0 {
return nil
}
out := make(map[string]struct{}, len(s))
for _, v := range s {
out[v] = struct{}{}
}
return out
}
// domainsFromPunycode is a thin wrapper that converts a punycode list back to
// the domain.List type the route.Route struct expects. It accepts the
// punycode strings as-is (no extra decoding) — symmetric with
// route.Domains.ToPunycodeList() used in the encoder.
func domainsFromPunycode(punycoded []string) domain.List {
if len(punycoded) == 0 {
return nil
}
out := make(domain.List, 0, len(punycoded))
for _, d := range punycoded {
out = append(out, domain.Domain(d))
}
return out
}

View File

@@ -0,0 +1,323 @@
// Package networkmap contains the shared NetworkMap helpers that both the
// management server and the client agent need.
//
// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live
// here so the client can run the same conversion locally after deriving its
// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the
// server-side conversion package (which pulls in cloud integrations and is
// otherwise an unwanted internal import on the client).
//
// The helpers are pure functions over inputs — no caches, no IO, no logging
// beyond a context-aware error log when an individual user-id hash fails.
package networkmap
import (
"context"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"net/netip"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
// ToProtocolRoutes converts a slice of typed routes to their proto form.
func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, ToProtocolRoute(r))
}
return protoRoutes
}
// ToProtocolRoute converts one typed route to its proto form.
func ToProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// ToProtocolFirewallRules converts the firewall rules to the protocol form.
// When useSourcePrefixes is true, the compact SourcePrefixes field is
// populated alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes
// when includeIPv6 is true.
func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: GetProtoDirection(rule.Direction),
Action: GetProtoAction(rule.Action),
Protocol: GetProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if ShouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is
// unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if ShouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// GetProtoDirection converts the direction to proto.RuleDirection.
func GetProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
// GetProtoAction converts the action to proto.RuleAction.
func GetProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// GetProtoProtocol converts the protocol to proto.RuleProtocol.
func GetProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
case types.PolicyRuleProtocolNetbirdSSH:
return proto.RuleProtocol_NETBIRD_SSH
default:
return proto.RuleProtocol_UNKNOWN
}
}
// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo.
func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
// ShouldUsePortRange reports whether the firewall rule should use a port
// range rather than a single port (TCP/UDP without a single port).
func ShouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall
// rules to proto.
func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: GetProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: GetProtoProtocol(rule.Protocol),
PortInfo: GetProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form.
func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form.
func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// DNSConfigCache is the cache contract for amortising NameServerGroup
// proto-conversion across peers in the same account. Server uses a concrete
// implementation; client passes nil (no cross-peer caching needed when
// rebuilding a single NetworkMap from an envelope).
type DNSConfigCache interface {
GetNameServerGroup(key string) (*proto.NameServerGroup, bool)
SetNameServerGroup(key string, value *proto.NameServerGroup)
}
// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is
// non-nil, NameServerGroup proto values are cached by NSG.ID across calls —
// the server amortises this across peers, the client passes nil.
func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone))
}
for _, nsGroup := range update.NameServerGroups {
if cache != nil {
if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
continue
}
}
protoGroup := ConvertToProtoNameServerGroup(nsGroup)
if cache != nil {
cache.SetNameServerGroup(nsGroup.ID, protoGroup)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig
// entries to dst and returns the result.
func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and
// builds per-machine-user index maps. Returns (hashedUsers, machineUsers).
// Errors from individual hash failures are logged via the provided context;
// they leave the offending user out of the result but don't abort the build.
func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).WithError(err).Error("failed to hash user id")
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}

View File

@@ -0,0 +1,190 @@
package networkmap
import (
"context"
"encoding/base64"
"fmt"
"github.com/netbirdio/netbird/shared/management/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// EnvelopeResult is what the client engine consumes after receiving a
// component-format NetworkMap. Both fields are populated:
//
// - NetworkMap is the *proto.NetworkMap shape the engine reads today via
// update.GetNetworkMap() — built from the envelope's components by
// running Calculate() locally + converting back through the shared
// proto helpers + merging the optional ProxyPatch.
// - Components is the *types.NetworkMapComponents the engine retains so
// future incremental delta updates have a base to apply changes
// against. The client keeps it under its sync lock.
type EnvelopeResult struct {
NetworkMap *proto.NetworkMap
Components *types.NetworkMapComponents
}
// EnvelopeToNetworkMap is the full client-side pipeline: decode the
// component envelope back to a typed NetworkMapComponents, run Calculate()
// locally to produce the typed NetworkMap, convert it to the wire form the
// engine consumes, and fold in any ProxyPatch the server attached.
//
// localPeerKey is the receiving peer's WG pub key (used to derive
// includeIPv6 / useSourcePrefixes from the receiving peer's own record in
// the components struct, mirroring legacy ToSyncResponse behaviour).
//
// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when
// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries.
func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) {
components, err := DecodeEnvelope(env)
if err != nil {
return nil, fmt.Errorf("decode envelope: %w", err)
}
// Find the receiving peer in the decoded components by WG key.
// c.Peers is keyed by canonical base64 of the raw 32-byte pub key
// (decoder re-encodes the bytes off the wire). The caller may pass a
// non-canonical encoding (some persisted production keys carry
// non-zero trailing padding bits that survived a legacy import), so
// round-trip through raw bytes once to canonicalize before lookup.
canonicalKey := canonicalizeWgKey(localPeerKey)
localPeer := components.Peers[canonicalKey]
if localPeer == nil {
return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers))
}
components.PeerID = canonicalKey
includeIPv6 := localPeer.SupportsIPv6() && localPeer.IPv6.IsValid()
useSourcePrefixes := localPeer.SupportsSourcePrefixes()
typedNM := components.Calculate(ctx)
full := env.GetFull()
dnsFwdPort := int64(0)
if full != nil {
dnsFwdPort = full.DnsForwarderPort
}
protoNM := &proto.NetworkMap{
Serial: typedNM.Network.CurrentSerial(),
}
if full != nil {
protoNM.PeerConfig = full.PeerConfig
}
protoNM.Routes = ToProtocolRoutes(typedNM.Routes)
protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort)
remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6)
protoNM.RemotePeers = remotePeers
protoNM.RemotePeersIsEmpty = len(remotePeers) == 0
protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6)
firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes)
protoNM.FirewallRules = firewallRules
protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules)
protoNM.RoutesFirewallRules = routesFirewallRules
protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if typedNM.AuthorizedUsers != nil {
hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers)
userIDClaim := ""
if full != nil {
userIDClaim = full.UserIdClaim
}
protoNM.SshAuth = &proto.SSHAuth{
AuthorizedUsers: hashedUsers,
MachineUsers: machineUsers,
UserIDClaim: userIDClaim,
}
}
if typedNM.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules))
for _, rule := range typedNM.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
protoNM.ForwardingRules = forwardingRules
}
// Merge the proxy patch the server attached. Mirrors the legacy
// NetworkMap.Merge step that the server runs after Calculate().
if full != nil && full.ProxyPatch != nil {
mergeProxyPatch(protoNM, full.ProxyPatch)
}
return &EnvelopeResult{
NetworkMap: protoNM,
Components: components,
}, nil
}
// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the
// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge
// — same six collections, deduplicated where the legacy merge dedupes.
func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) {
nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers)
nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers)
nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...)
nm.Routes = append(nm.Routes, patch.Routes...)
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...)
nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...)
if len(nm.RemotePeers) > 0 {
nm.RemotePeersIsEmpty = false
}
if len(nm.FirewallRules) > 0 {
nm.FirewallRulesIsEmpty = false
}
if len(nm.RoutesFirewallRules) > 0 {
nm.RoutesFirewallRulesIsEmpty = false
}
}
// appendUniquePeers dedupes by WgPubKey — mirrors legacy
// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the
// closest stable identifier is WgPubKey).
func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig {
if len(extra) == 0 {
return dst
}
seen := make(map[string]struct{}, len(dst))
for _, p := range dst {
if p == nil {
continue
}
seen[p.WgPubKey] = struct{}{}
}
for _, p := range extra {
if p == nil {
continue
}
if _, ok := seen[p.WgPubKey]; ok {
continue
}
seen[p.WgPubKey] = struct{}{}
dst = append(dst, p)
}
return dst
}
func trimKey(s string) string {
if len(s) > 12 {
return s[:12]
}
return s
}
// canonicalizeWgKey normalises a base64-encoded WireGuard public key so it
// matches the canonical encoding emitted by the envelope decoder. Returns
// the input unchanged when it does not decode to 32 raw bytes (caller will
// hit a miss in the peer map and surface the error).
func canonicalizeWgKey(s string) string {
raw, err := base64.StdEncoding.DecodeString(s)
if err != nil || len(raw) != 32 {
return s
}
return base64.StdEncoding.EncodeToString(raw)
}

View File

@@ -0,0 +1,201 @@
package networkmap_test
import (
"context"
"crypto/rand"
"encoding/base64"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline:
// build a small components struct, encode an envelope, marshal/unmarshal the
// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is
// non-empty and consistent.
func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err, "marshal envelope")
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err, "EnvelopeToNetworkMap")
require.NotNil(t, result)
require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil")
require.NotNil(t, result.Components, "Components must be retained for future delta updates")
require.NotNil(t, result.Components.AccountSettings)
require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer")
require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules")
}
// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the
// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into
// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP
// before forming firewall rules. Without that rewrite, agents fall into
// UNKNOWN-protocol handling, which on some platforms downgrades to
// allow-all — a real security regression.
func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
// Replace the smoke policy with a NetbirdSSH-protocol allow.
c.Policies = []*types.Policy{{
ID: "pol-ssh", AccountSeqID: 2, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-ssh",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}}
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err)
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded))
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err)
require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules")
for i, fr := range result.NetworkMap.FirewallRules {
require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol,
"FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i)
}
}
func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) {
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud")
require.Error(t, err, "nil envelope must produce an error rather than panic")
}
func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) {
env := &proto.NetworkMapEnvelope{}
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud")
require.Error(t, err, "envelope with no Full payload must produce an error")
}
// TestDecodeEnvelope_MalformedWgKeyPeerSkipped feeds an envelope where one
// peer has a wg_pub_key that is not 32 bytes long. The decoder must skip
// that peer (keeping the rest of the snapshot usable) instead of aborting
// the whole sync — mirrors legacy behaviour that tolerates an occasional
// bad row.
func TestDecodeEnvelope_MalformedWgKeyPeerSkipped(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, envelope.GetFull())
full := envelope.GetFull()
require.Len(t, full.Peers, 2, "smoke fixture should have two peers")
// Truncate the second peer's wg_pub_key so it fails the length gate.
full.Peers[1].WgPubKey = full.Peers[1].WgPubKey[:31]
wire, err := goproto.Marshal(envelope)
require.NoError(t, err, "marshal envelope")
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err, "EnvelopeToNetworkMap must tolerate one bad peer key")
require.NotNil(t, result)
require.NotNil(t, result.Components)
require.Len(t, result.Components.Peers, 1, "the well-formed peer survives, the malformed one is dropped")
}
// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1
// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient
// to validate the encode → marshal → decode → Calculate pipeline produces
// non-empty output.
func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) {
t.Helper()
peerAKey := randomWgKey(t)
peerBKey := randomWgKey(t)
peerA := &nbpeer.Peer{
ID: "peer-A",
Key: peerAKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peerA",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-B",
Key: peerBKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
DNSLabel: "peerB",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
group := &types.Group{
ID: "group-all", AccountSeqID: 1, Name: "All",
Peers: []string{"peer-A", "peer-B"},
}
policy := &types.Policy{
ID: "pol-allow", AccountSeqID: 1, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-allow",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}
c := &types.NetworkMapComponents{
PeerID: "peer-A",
Network: &types.Network{
Identifier: "net-smoke",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 1,
},
AccountSettings: &types.AccountSettingsInfo{},
DNSSettings: &types.DNSSettings{},
Peers: map[string]*nbpeer.Peer{
"peer-A": peerA,
"peer-B": peerB,
},
Groups: map[string]*types.Group{
"group-all": group,
},
Policies: []*types.Policy{policy},
}
return c, peerAKey
}
func randomWgKey(t *testing.T) string {
t.Helper()
var raw [32]byte
_, err := rand.Read(raw[:])
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(raw[:])
}

File diff suppressed because it is too large Load Diff

View File

@@ -52,6 +52,14 @@ service ManagementService {
// Executes a job on a target peer (e.g., debug bundle) // Executes a job on a target peer (e.g., debug bundle)
rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {} rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {}
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
// but does not redo the network-map sync. Only valid for SSO-registered peers where
// login expiration is enabled. The tunnel remains up.
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
rpc ExtendAuthSession(EncryptedMessage) returns (EncryptedMessage) {}
// CreateExpose creates a temporary reverse proxy service for a peer // CreateExpose creates a temporary reverse proxy service for a peer
rpc CreateExpose(EncryptedMessage) returns (EncryptedMessage) {} rpc CreateExpose(EncryptedMessage) returns (EncryptedMessage) {}
@@ -133,6 +141,21 @@ message SyncResponse {
// Posture checks to be evaluated by client // Posture checks to be evaluated by client
repeated Checks Checks = 6; repeated Checks Checks = 6;
// 3-state session deadline. Carried on every Sync snapshot so admin-side
// changes propagate live without a client reconnect.
// field unset (nil) → snapshot carries no info; client keeps the
// deadline it already had
// set, seconds=0 nanos=0 → explicit "expiry disabled" or peer is not
// SSO-registered; client clears its anchor
// set, valid timestamp → new absolute UTC deadline
google.protobuf.Timestamp sessionExpiresAt = 7;
// NetworkMapEnvelope carries the component-based wire format for peers that
// advertise PeerCapabilityComponentNetworkMap. When set, NetworkMap (field 5)
// is left empty: management ships components and the client runs Calculate()
// locally instead of receiving an expanded NetworkMap.
NetworkMapEnvelope NetworkMapEnvelope = 8;
} }
message SyncMetaRequest { message SyncMetaRequest {
@@ -212,6 +235,8 @@ enum PeerCapability {
PeerCapabilitySourcePrefixes = 1; PeerCapabilitySourcePrefixes = 1;
// Client handles IPv6 overlay addresses and firewall rules. // Client handles IPv6 overlay addresses and firewall rules.
PeerCapabilityIPv6Overlay = 2; PeerCapabilityIPv6Overlay = 2;
// Client receives NetworkMap as components and assembles it locally.
PeerCapabilityComponentNetworkMap = 3;
} }
// PeerSystemMeta is machine meta data like OS and version. // PeerSystemMeta is machine meta data like OS and version.
@@ -244,6 +269,31 @@ message LoginResponse {
PeerConfig peerConfig = 2; PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client // Posture checks to be evaluated by client
repeated Checks Checks = 3; repeated Checks Checks = 3;
// 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt.
// field unset (nil) → no info; client keeps any deadline it had
// set, seconds=0 nanos=0 → explicit "expiry disabled" / non-SSO peer
// set, valid timestamp → new absolute UTC deadline
google.protobuf.Timestamp sessionExpiresAt = 4;
}
// ExtendAuthSessionRequest carries a fresh JWT to refresh the peer's session deadline.
// The encrypted body of an EncryptedMessage with this payload is sent to the
// ExtendAuthSession RPC.
message ExtendAuthSessionRequest {
// SSO token (must be a fresh, valid JWT for the peer's owning user)
string jwtToken = 1;
// Meta data of the peer (used for IdP user info refresh consistent with Login)
PeerSystemMeta meta = 2;
}
// ExtendAuthSessionResponse contains the refreshed session deadline.
message ExtendAuthSessionResponse {
// 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt.
// In practice ExtendAuthSession only succeeds for SSO peers with expiry
// enabled, so this carries a valid timestamp on the success path. The
// 3-state encoding is documented here for symmetry with Login/Sync.
google.protobuf.Timestamp sessionExpiresAt = 1;
} }
message ServerKeyResponse { message ServerKeyResponse {
@@ -569,6 +619,13 @@ enum RuleProtocol {
UDP = 3; UDP = 3;
ICMP = 4; ICMP = 4;
CUSTOM = 5; CUSTOM = 5;
// NETBIRD_SSH (types.PolicyRuleProtocolType "netbird-ssh") is the marker
// policy rule that drives SSH-server activation in Calculate(). The legacy
// proto.FirewallRule path doesn't ship this value (Calculate already
// expands SSH rules into TCP/22 before encoding), but the components path
// ships RAW policies — the client must see this protocol to derive
// AuthorizedUsers locally.
NETBIRD_SSH = 6;
} }
enum RuleDirection { enum RuleDirection {
@@ -709,3 +766,462 @@ message StopExposeRequest {
} }
message StopExposeResponse {} message StopExposeResponse {}
// =====================================================================
// Component-based NetworkMap wire format (PeerCapabilityComponentNetworkMap).
//
// Peers that advertise this capability receive NetworkMap building blocks
// (peers + groups + policies + routes + dns + ssh + forwarding) and run the
// expansion (Calculate) locally instead of receiving a fully-expanded
// NetworkMap from the server.
// =====================================================================
// NetworkMapEnvelope wraps either a full snapshot or a delta. Only Full is
// emitted today; Delta is reserved for the incremental-update work.
message NetworkMapEnvelope {
oneof payload {
NetworkMapComponentsFull full = 1;
NetworkMapComponentsDelta delta = 2;
}
}
// NetworkMapComponentsFull is the full per-peer component snapshot. The
// client decodes it into a types.NetworkMapComponents and runs Calculate()
// locally to produce the same NetworkMap the legacy server path would have
// produced. Every field carries RAW component data — no server-side
// expansion (firewall rules, DNS config, SSH auth, route firewall rules,
// forwarding rules) is shipped; the client computes those itself.
message NetworkMapComponentsFull {
uint64 serial = 1;
// Peer config for the receiving peer (legacy proto.PeerConfig kept as-is —
// it carries the receiving peer's own overlay address, FQDN, SSH config).
PeerConfig peer_config = 2;
// Account-level network metadata (id, IPv4/IPv6 overlay subnets, DNS,
// serial). Mirrors types.Network.
AccountNetwork network = 3;
// Account-level settings the client needs for its local Calculate().
AccountSettingsCompact account_settings = 4;
// Account DNS settings (mirrors types.DNSSettings).
DNSSettingsCompact dns_settings = 5;
// Domain shared across all peers in this account, e.g. "netbird.cloud".
// Each peer's FQDN is dns_label + "." + dns_domain.
string dns_domain = 6;
// Custom-zone domain for this peer's view (c.CustomZoneDomain). Empty when
// the peer has no custom zone records.
string custom_zone_domain = 7;
// Deduplicated agent versions; PeerCompact.agent_version_idx indexes here.
// Empty string at index 0 if any peer has no version.
repeated string agent_versions = 8;
// All peers (deduplicated). The client splits peers into online / offline
// locally using account_settings.peer_login_expiration on receive.
repeated PeerCompact peers = 9;
// Indexes into peers for the subset that may act as routers.
repeated uint32 router_peer_indexes = 10;
// Policies that affect the receiving peer.
repeated PolicyCompact policies = 11;
// Groups in unspecified order — clients key off id (account_seq_id).
repeated GroupCompact groups = 12;
// Routes relevant to this peer, raw shape (mirrors []*route.Route).
repeated RouteRaw routes = 13;
// Nameserver groups (mirrors []*nbdns.NameServerGroup).
repeated NameServerGroupRaw nameserver_groups = 14;
// All DNS records the client needs to assemble its custom zone. Reuses
// the existing SimpleRecord wire shape.
repeated SimpleRecord all_dns_records = 15;
// Custom zones (typically the peer's own zone). Reuses the existing
// CustomZone wire shape.
repeated CustomZone account_zones = 16;
// Network resources (mirrors []*resourceTypes.NetworkResource).
repeated NetworkResourceRaw network_resources = 17;
// Routers per network. Outer key: network account_seq_id. Each entry is
// the set of routers backing that network for this peer's view.
//
// INCOMPATIBLE WIRE CHANGE: the map key changed from string (network xid)
// to uint32 (account_seq_id). Field 18 was reused without a `reserved`
// entry because capability=3 has never been released — every cap=3
// producer and consumer carries the same regenerated descriptor. Do NOT
// reuse this pattern for any further wire change once cap=3 ships.
map<uint32, NetworkRouterList> routers_map = 18;
// For each NetworkResource account_seq_id, the indexes into policies[]
// that apply to it.
//
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
map<uint32, PolicyIndexes> resource_policies_map = 19;
// Group-id (account_seq_id) → user ids authorized for SSH on members.
map<uint32, UserIDList> group_id_to_user_ids = 20;
// Account-level allowed user ids (used by Calculate() when assembling SSH
// authorized users for the receiving peer).
repeated string allowed_user_ids = 21;
// Per posture-check account_seq_id, the set of peer indexes that failed
// the check. Server-side evaluation result; clients do not re-evaluate.
//
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
map<uint32, PeerIndexSet> posture_failed_peers = 22;
// Account-level DNS forwarder port (mirrors the legacy
// proto.DNSConfig.ForwarderPort). Computed by the controller from peer
// versions; clients fold it into their Calculate() DNS output.
int64 dns_forwarder_port = 23;
// Pre-expanded NetworkMap fragments injected post-Calculate by external
// controllers (BYOP / port-forwarding proxies). The receiving client
// merges these into its locally-computed NetworkMap the same way the
// legacy server does via NetworkMap.Merge — so downstream consumers see
// a unified merged result regardless of source.
ProxyPatch proxy_patch = 24;
// SSH UserIDClaim — server-side HttpServerConfig.AuthUserIDClaim, or
// "sub" by default. Populated in proto.SSHAuth.UserIDClaim when the
// client rebuilds the NetworkMap from this envelope. Empty when the
// account has no AuthorizedUsers (and thus no SshAuth to populate).
string user_id_claim = 25;
// Reserved for future component additions (incremental_serial, parent_seq,
// etc.) without forcing a renumber.
reserved 26 to 50;
}
// ProxyPatch carries NetworkMap fragments that don't fit the component-graph
// model — they're pre-expanded by external controllers (BYOP /
// port-forwarding proxies) and injected post-Calculate. Fields use the
// legacy wire types because the proxy delivers them pre-formed; there is
// no raw component shape to convert from. Empty when no proxy is active.
message ProxyPatch {
repeated RemotePeerConfig peers = 1;
repeated RemotePeerConfig offline_peers = 2;
repeated FirewallRule firewall_rules = 3;
repeated Route routes = 4;
repeated RouteFirewallRule route_firewall_rules = 5;
repeated ForwardingRule forwarding_rules = 6;
}
// AccountSettingsCompact carries the account-level settings the client needs
// to evaluate locally. Mirrors the subset of types.AccountSettingsInfo that
// Calculate() actually reads — login-expiration (used to filter expired
// peers). Inactivity expiration is purely server-side bookkeeping and is not
// shipped.
message AccountSettingsCompact {
bool peer_login_expiration_enabled = 1;
// Login expiration window. Unit is nanoseconds (matches time.Duration).
int64 peer_login_expiration_ns = 2;
}
// AccountNetwork is the account-level overlay metadata. Mirrors types.Network
// so the client can populate NetworkMap.Network without a server round-trip.
message AccountNetwork {
string identifier = 1;
// IPv4 overlay subnet in CIDR form (e.g. "100.64.0.0/16").
string net_cidr = 2;
// IPv6 ULA overlay subnet in CIDR form (e.g. "fd00:4e42::/64"). Empty when
// the account has no IPv6 overlay yet.
string net_v6_cidr = 3;
string dns = 4;
uint64 serial = 5;
}
// NetworkMapComponentsDelta is reserved for the incremental update
// protocol. Field numbers 1100 are pre-allocated to keep room for the
// planned event types without needing a renumber.
message NetworkMapComponentsDelta {
reserved 1 to 100;
}
// PeerCompact is the wire-shape of a remote peer used by the component
// format. It carries every field of types.Peer that the client's local
// Calculate() reads — including the trio needed to evaluate
// LoginExpired() (added_with_sso_login + login_expiration_enabled +
// last_login_unix_nano). Fields the client does not consume (Status,
// CreatedAt, etc.) are not shipped.
message PeerCompact {
// Raw 32-byte WireGuard public key (no base64 wrapping).
bytes wg_pub_key = 1;
// Raw 4-byte IPv4 overlay address. Always a /32 host route, so no prefix
// byte is needed.
bytes ip = 2;
// Raw 16-byte IPv6 overlay address; always a /128 host route. Empty when
// the peer has no IPv6 overlay address.
bytes ipv6 = 3;
// Raw SSH public key bytes (or empty).
bytes ssh_pub_key = 4;
// DNS label without the account's domain suffix. Full FQDN is
// dns_label + "." + NetworkMapComponentsFull.dns_domain.
string dns_label = 5;
// Index into NetworkMapComponentsFull.agent_versions.
uint32 agent_version_idx = 6;
// True iff the peer was added via SSO login (i.e., types.Peer.UserID is
// non-empty). Combined with login_expiration_enabled and
// last_login_unix_nano this lets the client reproduce
// (*Peer).LoginExpired() locally.
bool added_with_sso_login = 7;
// True when the peer's login can expire — mirrors
// types.Peer.LoginExpirationEnabled.
bool login_expiration_enabled = 8;
// Unix-nanosecond timestamp of the peer's last login. 0 when the peer has
// never logged in (server stores nil; client treats 0 as "epoch", which
// makes a fresh peer immediately expired iff login_expiration_enabled is
// true — the same semantics as types.Peer.GetLastLogin).
int64 last_login_unix_nano = 9;
// True when the peer has an SSH server enabled locally. Used by the
// legacy SSH path in Calculate() (`policyRuleImpliesLegacySSH`): a rule
// with protocol ALL/TCP-with-SSH-ports activates SSH for the receiving
// peer when this bit is set, even without an explicit NetbirdSSH rule.
bool ssh_enabled = 10;
reserved 11; // was: id (string xid)
// Mirror of types.Peer.SupportsIPv6() — !Meta.Flags.DisableIPv6 &&
// HasCapability(PeerCapabilityIPv6Overlay). Used by the local peer's
// Calculate() when deciding whether to emit IPv6 firewall rules
// (appendIPv6FirewallRule) against this peer's IPv6 address.
bool supports_ipv6 = 12;
// Mirror of types.Peer.SupportsSourcePrefixes() —
// HasCapability(PeerCapabilitySourcePrefixes). Determines whether the
// local peer's Calculate() emits SourcePrefixes alongside legacy PeerIP
// fields in proto.FirewallRule.
bool supports_source_prefixes = 13;
// Mirror of types.Peer.Meta.Flags.ServerSSHAllowed. Read by Calculate()
// when expanding TCP port-22 firewall rules — the native SSH companion
// (port 22022) is only added when this flag is set and the peer agent
// version supports it.
bool server_ssh_allowed = 14;
}
// PolicyCompact is the compact form of a policy rule. Group references use
// the per-account integer ids from account_seq_counters; the client resolves
// them against NetworkMapComponentsFull.groups. Direction is derived per-peer
// on the client (ingress when the peer is in destination_group_ids, egress
// when in source_group_ids; both when bidirectional).
message PolicyCompact {
// Per-account integer id (matches policies.account_seq_id). Used as a
// stable reference for ResourcePoliciesMap.indexes and future delta
// updates.
uint32 id = 1;
RuleAction action = 2;
RuleProtocol protocol = 3;
bool bidirectional = 4;
// Single ports referenced by the rule.
repeated uint32 ports = 5;
// Port ranges (start..end) referenced by the rule.
repeated PortInfo.Range port_ranges = 6;
// Group ids (account_seq_id) of source / destination groups.
repeated uint32 source_group_ids = 7;
repeated uint32 destination_group_ids = 8;
reserved 9; // was: xid (string)
// SSH authorization fields. PolicyRule.AuthorizedGroups maps the rule's
// applicable group ids (account_seq_id) to a list of local-user names —
// when a peer in one of those groups is the SSH destination, the named
// local users gain access. AuthorizedUser is the single-user form
// (legacy: rule scopes SSH to one specific user id).
//
// Both fields are only consumed by Calculate() when the rule's protocol
// is NetbirdSSH (or the legacy implicit-SSH heuristic).
map<uint32, UserNameList> authorized_groups = 10;
string authorized_user = 11;
// Resource-typed rule sources/destinations. When a rule targets a specific
// peer (rather than groups), Calculate() reads SourceResource /
// DestinationResource — without these the rule's connection resources
// can't be produced on the client. ResourceCompact's peer_index refers to
// NetworkMapComponentsFull.peers; type is the raw ResourceType string
// ("peer", "host", "subnet", "domain"). Only "peer" is meaningful for
// Calculate's resource-typed rule path today.
ResourceCompact source_resource = 12;
ResourceCompact destination_resource = 13;
// Posture-check seq ids gating this policy's source peers. Calculate()
// reads them when filtering rule peers (peers that fail any listed check
// are dropped from sourcePeers). Match keys in
// NetworkMapComponentsFull.posture_failed_peers.
repeated uint32 source_posture_check_seq_ids = 15;
reserved 14; // was: source_posture_check_ids (repeated string xid)
}
// ResourceCompact mirrors types.Resource. Used by PolicyCompact to carry
// rule.SourceResource / rule.DestinationResource when the rule targets a
// specific resource (typically a peer) rather than groups.
// peer_index_set tells whether peer_index is valid (proto3 uint32 cannot
// disambiguate "0" from "unset"); set only when type == "peer".
message ResourceCompact {
string type = 1;
bool peer_index_set = 2;
uint32 peer_index = 3;
reserved 4; // future: host/subnet/domain references when needed
}
// UserNameList is a list of local-user names — used as the value type in
// PolicyCompact.authorized_groups.
message UserNameList {
repeated string names = 1;
}
// GroupCompact is the wire-shape of a group: per-account integer id, optional
// name, and indexes into NetworkMapComponentsFull.peers identifying members.
message GroupCompact {
// Per-account integer id (matches groups.account_seq_id). Used by
// PolicyCompact.source_group_ids / destination_group_ids.
uint32 id = 1;
// Group name; only sent when non-empty (clients use it for diagnostics).
string name = 2;
// Indexes into NetworkMapComponentsFull.peers.
repeated uint32 peer_indexes = 3;
}
// DNSSettingsCompact mirrors types.DNSSettings.
message DNSSettingsCompact {
// Group ids (account_seq_id) whose DNS management is disabled.
repeated uint32 disabled_management_group_ids = 1;
}
// RouteRaw mirrors *route.Route (the domain type), trimmed to fields that
// types.NetworkMapComponents.Calculate() reads. Group references are
// account_seq_ids; the routing peer (when set) is referenced by index into
// NetworkMapComponentsFull.peers.
message RouteRaw {
// Per-account integer id (matches routes.account_seq_id).
uint32 id = 1;
string net_id = 2;
string description = 3;
// Either network_cidr (e.g. "10.0.0.0/16") or domains is set, not both.
string network_cidr = 4;
repeated string domains = 5;
bool keep_route = 6;
// Routing peer reference: peer_index_set tells whether peer_index is valid
// (proto3 uint32 cannot disambiguate "0" from "unset"). Mutually exclusive
// with peer_group_ids.
//
// peer_index decodes back to types.Peer.ID (the peer's xid string), NOT
// to its WireGuard public key. This matches the server-side data flow:
// c.Routes carry route.Peer = peer.ID, and getRoutingPeerRoutes mutates
// it to peer.Key only after the route has been admitted to the network
// map. Decoders MUST set Route.Peer = peer.ID; the legacy Calculate()
// path will substitute the WG key downstream.
bool peer_index_set = 7;
uint32 peer_index = 8;
repeated uint32 peer_group_ids = 9;
int32 network_type = 10;
bool masquerade = 11;
int32 metric = 12;
bool enabled = 13;
repeated uint32 group_ids = 14;
repeated uint32 access_control_group_ids = 15;
bool skip_auto_apply = 16;
reserved 17; // was: xid (string)
}
// NameServerGroupRaw mirrors *nbdns.NameServerGroup. Distinct from the
// legacy NameServerGroup (which is the wire-trimmed shape consumed by
// proto.DNSConfig and lacks the Name/Description/Groups/Enabled fields).
message NameServerGroupRaw {
uint32 id = 1; // nameserver_groups.account_seq_id
string name = 2;
string description = 3;
// Reuses the legacy NameServer wire shape (IP as string).
repeated NameServer nameservers = 4;
// Group ids (account_seq_id) the NSG distributes nameservers to.
repeated uint32 group_ids = 5;
bool primary = 6;
repeated string domains = 7;
bool enabled = 8;
bool search_domains_enabled = 9;
}
// NetworkResourceRaw mirrors *resourceTypes.NetworkResource.
//
// INCOMPATIBLE WIRE CHANGE: field 2 changed from `string network_id` (xid)
// to `uint32 network_seq` without a `reserved` entry. Safe only because
// capability=3 has never been released — every cap=3 producer and consumer
// carries the same regenerated descriptor. Do NOT reuse this pattern once
// cap=3 ships.
message NetworkResourceRaw {
uint32 id = 1; // network_resources.account_seq_id
uint32 network_seq = 2; // networks.account_seq_id (replaces xid)
string name = 3;
string description = 4;
// Resource type: "host" / "subnet" / "domain".
string type = 5;
string address = 6;
string domain_value = 7; // resource.Domain
string prefix_cidr = 8;
bool enabled = 9;
reserved 10; // was: xid (string)
}
// NetworkRouterList carries the routers backing one network.
message NetworkRouterList {
// Routers in this network, keyed by peer_index (the routing peer).
repeated NetworkRouterEntry entries = 1;
}
// NetworkRouterEntry mirrors a single *routerTypes.NetworkRouter; the routing
// peer is referenced by index into NetworkMapComponentsFull.peers.
message NetworkRouterEntry {
uint32 id = 1; // network_routers.account_seq_id
uint32 peer_index = 2;
bool peer_index_set = 3;
repeated uint32 peer_group_ids = 4;
bool masquerade = 5;
int32 metric = 6;
bool enabled = 7;
}
// PolicyIndexes is a list of indexes into NetworkMapComponentsFull.policies.
message PolicyIndexes {
repeated uint32 indexes = 1;
}
// UserIDList is a list of user ids — used as the value type in
// NetworkMapComponentsFull.group_id_to_user_ids.
message UserIDList {
repeated string user_ids = 1;
}
// PeerIndexSet is a set of peer indexes — used as the value type in
// NetworkMapComponentsFull.posture_failed_peers.
message PeerIndexSet {
repeated uint32 peer_indexes = 1;
}

View File

@@ -52,6 +52,13 @@ type ManagementServiceClient interface {
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle) // Executes a job on a target peer (e.g., debug bundle)
Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error) Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error)
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
// but does not redo the network-map sync. Only valid for SSO-registered peers where
// login expiration is enabled. The tunnel remains up.
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
ExtendAuthSession(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// CreateExpose creates a temporary reverse proxy service for a peer // CreateExpose creates a temporary reverse proxy service for a peer
CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// RenewExpose extends the TTL of an active expose session // RenewExpose extends the TTL of an active expose session
@@ -194,6 +201,15 @@ func (x *managementServiceJobClient) Recv() (*EncryptedMessage, error) {
return m, nil return m, nil
} }
func (c *managementServiceClient) ExtendAuthSession(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/ExtendAuthSession", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) { func (c *managementServiceClient) CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage) out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/CreateExpose", in, out, opts...) err := c.cc.Invoke(ctx, "/management.ManagementService/CreateExpose", in, out, opts...)
@@ -259,6 +275,13 @@ type ManagementServiceServer interface {
Logout(context.Context, *EncryptedMessage) (*Empty, error) Logout(context.Context, *EncryptedMessage) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle) // Executes a job on a target peer (e.g., debug bundle)
Job(ManagementService_JobServer) error Job(ManagementService_JobServer) error
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
// but does not redo the network-map sync. Only valid for SSO-registered peers where
// login expiration is enabled. The tunnel remains up.
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
ExtendAuthSession(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// CreateExpose creates a temporary reverse proxy service for a peer // CreateExpose creates a temporary reverse proxy service for a peer
CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// RenewExpose extends the TTL of an active expose session // RenewExpose extends the TTL of an active expose session
@@ -299,6 +322,9 @@ func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMe
func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error { func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error {
return status.Errorf(codes.Unimplemented, "method Job not implemented") return status.Errorf(codes.Unimplemented, "method Job not implemented")
} }
func (UnimplementedManagementServiceServer) ExtendAuthSession(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method ExtendAuthSession not implemented")
}
func (UnimplementedManagementServiceServer) CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) { func (UnimplementedManagementServiceServer) CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method CreateExpose not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateExpose not implemented")
} }
@@ -494,6 +520,24 @@ func (x *managementServiceJobServer) Recv() (*EncryptedMessage, error) {
return m, nil return m, nil
} }
func _ManagementService_ExtendAuthSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).ExtendAuthSession(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/ExtendAuthSession",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).ExtendAuthSession(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_CreateExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _ManagementService_CreateExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage) in := new(EncryptedMessage)
if err := dec(in); err != nil { if err := dec(in); err != nil {
@@ -583,6 +627,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Logout", MethodName: "Logout",
Handler: _ManagementService_Logout_Handler, Handler: _ManagementService_Logout_Handler,
}, },
{
MethodName: "ExtendAuthSession",
Handler: _ManagementService_ExtendAuthSession_Handler,
},
{ {
MethodName: "CreateExpose", MethodName: "CreateExpose",
Handler: _ManagementService_CreateExpose_Handler, Handler: _ManagementService_CreateExpose_Handler,

View File

@@ -0,0 +1,131 @@
package types
import (
"strconv"
"strings"
"github.com/netbirdio/netbird/management/server/posture"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
firewallRuleMinPortRangesVer = "0.48.0"
firewallRuleMinNativeSSHVer = "0.60.0"
nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
defaultSSHPortString = "22"
defaultSSHPortNumber = 22
)
type supportedFeatures struct {
nativeSSH bool
portRanges bool
}
type LookupMap map[string]struct{}
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
for _, pr := range portRanges {
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
return true
}
}
return false
}
func portsIncludesSSH(ports []string) bool {
for _, port := range ports {
if port == defaultSSHPortString || port == nativeSSHPortString {
return true
}
}
return false
}
// ExpandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules.
func ExpandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
var expanded []*FirewallRule
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
for _, portRange := range rule.PortRanges {
if len(rule.Ports) > 0 {
break
}
fr := base
if features.portRanges {
fr.PortRange = portRange
} else {
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded)
}
return expanded
}
func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule {
shouldAdd := false
for _, fr := range expanded {
if isPortInRule(nativeSSHPortString, 22022, fr) {
return expanded
}
if isPortInRule(defaultSSHPortString, 22, fr) {
shouldAdd = true
}
}
if !shouldAdd {
return expanded
}
fr := base
fr.Port = nativeSSHPortString
return append(expanded, &fr)
}
func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool {
return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End)
}
func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool {
return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP
}
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
if strings.Contains(peerVer, "dev") {
return supportedFeatures{true, true}
}
var features supportedFeatures
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer)
features.nativeSSH = err == nil && meetMinVer
if features.nativeSSH {
features.portRanges = true
} else {
meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
features.portRanges = err == nil && meetMinVer
}
return features
}

View File

@@ -47,11 +47,11 @@ func (r *FirewallRule) Equal(other *FirewallRule) bool {
return reflect.DeepEqual(r, other) return reflect.DeepEqual(r, other)
} }
// generateRouteFirewallRules generates a list of firewall rules for a given route. // GenerateRouteFirewallRules generates a list of firewall rules for a given route.
// For static routes, source ranges match the destination family (v4 or v6). // For static routes, source ranges match the destination family (v4 or v6).
// For dynamic routes (domain-based), separate v4 and v6 rules are generated // For dynamic routes (domain-based), separate v4 and v6 rules are generated
// so the routing peer's forwarding chain allows both address families. // so the routing peer's forwarding chain allows both address families.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule { func GenerateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
rules := make([]*RouteFirewallRule, 0) rules := make([]*RouteFirewallRule, 0)

View File

@@ -57,7 +57,7 @@ func TestGenerateRouteFirewallRules_V4Route(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
require.Len(t, rules, 1) require.Len(t, rules, 1)
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges, "v4 route should only have v4 sources") assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges, "v4 route should only have v4 sources")
@@ -86,7 +86,7 @@ func TestGenerateRouteFirewallRules_V6Route(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
require.Len(t, rules, 1) require.Len(t, rules, 1)
assert.Equal(t, []string{"fd00::1/128"}, rules[0].SourceRanges, "v6 route should only have v6 sources") assert.Equal(t, []string{"fd00::1/128"}, rules[0].SourceRanges, "v6 route should only have v6 sources")
@@ -115,7 +115,7 @@ func TestGenerateRouteFirewallRules_DynamicRoute_DualStack(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
require.Len(t, rules, 2, "dynamic route should produce both v4 and v6 rules") require.Len(t, rules, 2, "dynamic route should produce both v4 and v6 rules")
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
@@ -143,7 +143,7 @@ func TestGenerateRouteFirewallRules_DynamicRoute_NoV6Peers(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
require.Len(t, rules, 1, "no v6 peers means only v4 rule") require.Len(t, rules, 1, "no v6 peers means only v4 rule")
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
@@ -173,7 +173,7 @@ func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
assert.Empty(t, rules, "v6 route should produce no rules when includeIPv6 is false") assert.Empty(t, rules, "v6 route should produce no rules when includeIPv6 is false")
}) })
@@ -190,7 +190,7 @@ func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) {
Protocol: PolicyRuleProtocolALL, Protocol: PolicyRuleProtocolALL,
} }
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false) rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
require.Len(t, rules, 1, "dynamic route with includeIPv6=false should produce only v4 rule") require.Len(t, rules, 1, "dynamic route with includeIPv6=false should produce only v4 rule")
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges) assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
}) })

View File

@@ -19,6 +19,10 @@ type Group struct {
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"` AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
// Name visible in the UI // Name visible in the UI
Name string Name string
@@ -41,6 +45,14 @@ type GroupPeer struct {
PeerID string `gorm:"primaryKey"` PeerID string `gorm:"primaryKey"`
} }
// HasSeqID reports whether the group has been persisted long enough to have a
// per-account sequence id allocated. Wire encoders that key off AccountSeqID
// must skip groups that return false here — otherwise multiple unpersisted
// groups would collide on id 0.
func (g *Group) HasSeqID() bool {
return g != nil && g.AccountSeqID != 0
}
func (g *Group) LoadGroupPeers() { func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers)) g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers { for i, peer := range g.GroupPeers {
@@ -74,6 +86,7 @@ func (g *Group) Copy() *Group {
group := &Group{ group := &Group{
ID: g.ID, ID: g.ID,
AccountID: g.AccountID, AccountID: g.AccountID,
AccountSeqID: g.AccountSeqID,
Name: g.Name, Name: g.Name,
Issued: g.Issued, Issued: g.Issued,
Peers: make([]string, len(g.Peers)), Peers: make([]string, len(g.Peers)),

View File

@@ -42,6 +42,17 @@ type NetworkMapComponents struct {
PostureFailedPeers map[string]map[string]struct{} PostureFailedPeers map[string]map[string]struct{}
RouterPeers map[string]*nbpeer.Peer RouterPeers map[string]*nbpeer.Peer
// NetworkXIDToSeq maps Network.ID (xid) → AccountSeqID. Populated by the
// account-side component builder; consumed by the envelope encoder to
// translate RoutersMap keys and NetworkResource.NetworkID references
// to compact uint32 ids. Legacy Calculate() doesn't consult it.
NetworkXIDToSeq map[string]uint32
// PostureCheckXIDToSeq maps posture.Checks.ID (xid) → AccountSeqID.
// Same role as NetworkXIDToSeq, used for PostureFailedPeers keys and
// policy SourcePostureChecks references.
PostureCheckXIDToSeq map[string]uint32
} }
type AccountSettingsInfo struct { type AccountSettingsInfo struct {
@@ -252,7 +263,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
default: default:
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
} }
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled { } else if peerInDestinations && PolicyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
sshEnabled = true sshEnabled = true
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
} }
@@ -319,15 +330,15 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr) rules = append(rules, &fr)
} else { } else {
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...) rules = append(rules, ExpandPortsAndRanges(fr, rule, targetPeer)...)
} }
rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{ rules = AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, FirewallRuleContext{
direction: direction, Direction: direction,
dirStr: dirStr, DirStr: dirStr,
protocolStr: protocolStr, ProtocolStr: protocolStr,
actionStr: actionStr, ActionStr: actionStr,
portsJoined: portsJoined, PortsJoined: portsJoined,
}) })
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
@@ -680,7 +691,7 @@ func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID
} }
rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6) rules := GenerateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
fwRules = append(fwRules, rules...) fwRules = append(fwRules, rules...)
} }
} }
@@ -949,21 +960,21 @@ func (c *NetworkMapComponents) addNetworksRoutingPeers(
return peersToConnect return peersToConnect
} }
type firewallRuleContext struct { type FirewallRuleContext struct {
direction int Direction int
dirStr string DirStr string
protocolStr string ProtocolStr string
actionStr string ActionStr string
portsJoined string PortsJoined string
} }
func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc firewallRuleContext) []*FirewallRule { func AppendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc FirewallRuleContext) []*FirewallRule {
if !peer.IPv6.IsValid() || !targetPeer.SupportsIPv6() || !targetPeer.IPv6.IsValid() { if !peer.IPv6.IsValid() || !targetPeer.SupportsIPv6() || !targetPeer.IPv6.IsValid() {
return rules return rules
} }
v6IP := peer.IPv6.String() v6IP := peer.IPv6.String()
v6RuleID := rule.ID + v6IP + rc.dirStr + rc.protocolStr + rc.actionStr + rc.portsJoined v6RuleID := rule.ID + v6IP + rc.DirStr + rc.ProtocolStr + rc.ActionStr + rc.PortsJoined
if _, ok := rulesExists[v6RuleID]; ok { if _, ok := rulesExists[v6RuleID]; ok {
return rules return rules
} }
@@ -972,12 +983,12 @@ func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct
v6fr := FirewallRule{ v6fr := FirewallRule{
PolicyID: rule.ID, PolicyID: rule.ID,
PeerIP: v6IP, PeerIP: v6IP,
Direction: rc.direction, Direction: rc.Direction,
Action: rc.actionStr, Action: rc.ActionStr,
Protocol: rc.protocolStr, Protocol: rc.ProtocolStr,
} }
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
return append(rules, &v6fr) return append(rules, &v6fr)
} }
return append(rules, expandPortsAndRanges(v6fr, rule, targetPeer)...) return append(rules, ExpandPortsAndRanges(v6fr, rule, targetPeer)...)
} }

View File

@@ -59,6 +59,10 @@ type Policy struct {
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"` AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
// Name of the Policy // Name of the Policy
Name string Name string
@@ -75,11 +79,19 @@ type Policy struct {
SourcePostureChecks []string `gorm:"serializer:json"` SourcePostureChecks []string `gorm:"serializer:json"`
} }
// HasSeqID reports whether the policy has been persisted long enough to have
// a per-account sequence id allocated. Wire encoders that key off
// AccountSeqID must skip policies that return false here.
func (p *Policy) HasSeqID() bool {
return p != nil && p.AccountSeqID != 0
}
// Copy returns a copy of the policy. // Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy { func (p *Policy) Copy() *Policy {
c := &Policy{ c := &Policy{
ID: p.ID, ID: p.ID,
AccountID: p.AccountID, AccountID: p.AccountID,
AccountSeqID: p.AccountSeqID,
Name: p.Name, Name: p.Name,
Description: p.Description, Description: p.Description,
Enabled: p.Enabled, Enabled: p.Enabled,