Compare commits

...

13 Commits

Author SHA1 Message Date
crn4
13e41e432c idp dex fix 2026-05-21 15:21:28 +02:00
crn4
efa6a3f502 missed file 2026-05-20 12:41:05 +02:00
crn4
5fbcdeceac more comments 2026-05-19 21:41:08 +02:00
crn4
3a1bbeba90 review comments 2026-05-19 20:27:50 +02:00
crn4
728057ef15 missed files for client side and shared files 2026-05-19 14:46:23 +02:00
crn4
582cd70086 client side and components on shared folder 2026-05-19 14:46:09 +02:00
crn4
9bbbafaf69 int id for networks and posture checks migration 2026-05-19 14:45:40 +02:00
crn4
672b057aa0 fix Group.Copy losing AccountSeqID 2026-05-19 14:43:59 +02:00
crn4
b9a0186200 fix routes filtering in account componnents 2026-05-19 14:43:49 +02:00
crn4
9083bdb977 capabilities conditioning 2026-05-19 14:43:38 +02:00
crn4
b194af48b8 wire size benches fix 2026-05-19 14:43:28 +02:00
crn4
4543780ef0 grpc components encoding with optimisations 2026-05-19 14:43:17 +02:00
crn4
2de0283971 init int inds migration 2026-05-19 14:42:55 +02:00
56 changed files with 9006 additions and 571 deletions

View File

@@ -61,9 +61,11 @@ import (
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
@@ -202,6 +204,13 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// latestComponents is the most-recent NetworkMapComponents decoded from
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
// NetworkMap that Calculate() produced from it so Step 3 incremental
// updates have a base to apply changes against. nil for legacy-format
// peers. Guarded by syncMsgMux.
latestComponents *types.NetworkMapComponents
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -865,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
// Envelope sync responses carry PeerConfig at the top level; legacy
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
if pc := update.GetPeerConfig(); pc != nil {
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
}
if update.GetNetbirdConfig() != nil {
@@ -907,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
nm := update.GetNetworkMap()
var (
nm *mgmProto.NetworkMap
components *types.NetworkMapComponents
)
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
// Components-format peer: decode the envelope back to typed
// components, run Calculate() locally, and convert to the wire
// NetworkMap shape the rest of the engine consumes. Components are
// retained so future incremental updates (Step 3) can apply deltas
// instead of doing a full reconstruction.
localKey := e.config.WgPrivateKey.PublicKey().String()
dnsName := ""
if pc := update.GetPeerConfig(); pc != nil {
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
// shared domain by stripping the peer's own label prefix. Falls
// back to empty if the FQDN doesn't have the expected shape.
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
}
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
if err != nil {
return fmt.Errorf("decode network map envelope: %w", err)
}
nm = result.NetworkMap
components = result.Components
} else {
nm = update.GetNetworkMap()
}
if nm == nil {
return nil
}
// Only retain the components view when the server sent the envelope
// path. A legacy proto.NetworkMap means components == nil; writing it
// here would clobber a previously-cached snapshot, breaking the Step 3
// incremental-delta base on a future envelope sync.
if components != nil {
e.latestComponents = components
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
@@ -937,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
// receiving peer's FQDN — the same value the management server fills as
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
// "netbird.cloud". An empty string is returned for unrecognized formats.
func extractDNSDomainFromFQDN(fqdn string) string {
for i := 0; i < len(fqdn); i++ {
if fqdn[i] == '.' && i+1 < len(fqdn) {
return fqdn[i+1:]
}
}
return ""
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too

View File

@@ -53,6 +53,9 @@ type NameServerGroup struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
// Name group name
Name string
// Description group description

View File

@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
if file == "" {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
}
return (&sql.SQLite3{File: file}).Open(logger)
return newSQLite3(file).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {

View File

@@ -20,7 +20,6 @@ import (
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Initialize SQLite storage
dbPath := filepath.Join(config.DataDir, "oidc.db")
sqliteConfig := &sql.SQLite3{File: dbPath}
sqliteConfig := newSQLite3(dbPath)
stor, err := sqliteConfig.Open(logger)
if err != nil {
return nil, fmt.Errorf("failed to open storage: %w", err)

View File

@@ -55,6 +55,15 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
// componentsDisabled is the kill switch for the component-based wire
// format. When true the controller emits legacy proto.NetworkMap to every
// peer regardless of capability — used to roll back instantly via a
// management restart from a bad components encoder.
//
// Set once in NewController from NB_NETWORK_MAP_COMPONENTS_DISABLE and
// never written after — readers race-free without a mutex.
componentsDisabled bool
}
type bufferUpdate struct {
@@ -81,12 +90,30 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
settingsManager: settingsManager,
dnsDomain: dnsDomain,
config: config,
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
}
}
// PeerNeedsComponents reports whether the gRPC layer should emit the
// component-based wire format for this peer. Combines the peer's advertised
// capability with the controller-level kill switch — callers ask exactly
// this question, so encapsulating it removes accidental double-checks.
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
}
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
// literal — matches the convention used elsewhere in the codebase
// (e.g. event.go's NB_TRAFFIC_EVENT_*) and reduces operator surprises.
func parseBoolEnv(key string) bool {
v, _ := strconv.ParseBool(os.Getenv(key))
return v
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
@@ -192,18 +219,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap := proxyNetworkMaps[p.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
var update *proto.SyncResponse
if result.IsComponents() {
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
// the client merges it into Calculate()'s output the same
// way the legacy server did via NetworkMap.Merge.
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
}
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
@@ -314,11 +349,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap := proxyNetworkMaps[peer.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
@@ -329,7 +364,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
var update *proto.SyncResponse
if result.IsComponents() {
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
}
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
@@ -376,6 +416,67 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// GetValidatedPeerWithComponents is the components-format counterpart of
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
// encodes both into the wire envelope. The caller is responsible for
// checking peer capability + componentsDisabled before dispatching here —
// this method does NOT branch on capability itself.
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, nil, 0, err
}
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
return nil, nil, nil, nil, 0, err
}
// Fetch the proxy network map fragment for this peer alongside the
// components — same single-account-load path the streaming controller
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
// instead of waiting for the next streaming push.
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)

View File

@@ -22,6 +22,10 @@ type Controller interface {
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
// PeerNeedsComponents combines the peer's advertised capability with the
// kill-switch flag — the only public predicate gRPC layers should ask.
PeerNeedsComponents(p *nbpeer.Peer) bool
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// GetValidatedPeerWithComponents mocks base method.
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMapComponents)
ret2, _ := ret[2].(*types.NetworkMap)
ret3, _ := ret[3].([]*posture.Checks)
ret4, _ := ret[4].(int64)
ret5, _ := ret[5].(error)
return ret0, ret1, ret2, ret3, ret4, ret5
}
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
}
// PeerNeedsComponents mocks base method.
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
ret0, _ := ret[0].(bool)
return ret0
}
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
}
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,815 @@
package grpc
import (
"encoding/base64"
"strconv"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// wgKeyRawLen is the raw byte length of a WireGuard public key.
const wgKeyRawLen = 32
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
// In Step 2 the envelope is fully self-contained — every field needed by the
// client's local Calculate() comes from the components struct itself. The
// only externally-supplied data is the receiving peer's PeerConfig (which is
// computed alongside the components in the network_map controller and reused
// from the legacy proto path) and the dns_domain string.
type ComponentsEnvelopeInput struct {
Components *types.NetworkMapComponents
PeerConfig *proto.PeerConfig
DNSDomain string
DNSForwarderPort int64
// UserIDClaim is the OIDC claim name the client should embed in
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
// is OK — client treats empty as "no SshAuth to build".
UserIDClaim string
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
// external controllers (BYOP/port-forwarding). Nil when no proxy data
// is present; encoder skips the field in that case.
ProxyPatch *proto.ProxyPatch
}
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
// wire envelope. The encoder is intentionally non-deterministic: it iterates
// Go maps in their native (random) order. Indexes inside the envelope
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
// are self-consistent within a single encode, so the decoder reconstructs
// the same typed objects regardless of emit order. Tests that need to
// compare envelopes do so semantically via proto round-trip + canonicalize,
// not byte-equal.
//
// Callers must NOT concatenate or merge envelopes from different encodes —
// index spaces are local to a single envelope. Delta sync (Step 3+) will
// use a different shape for the same reason.
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
c := in.Components
// Graceful degrade when components is nil — matches the legacy path's
// account_components.go:43 behaviour for missing/unvalidated peers
// (return a NetworkMap with only Network populated). The receiver gets
// an envelope it can decode without crashing; AccountSettings stays
// non-nil so client-side dereferences are safe.
if c == nil {
// Match legacy missing-peer minimum: a NetworkMap with only Network
// populated (account_components.go:43). The receiver gets enough to
// bootstrap (Network identifier, dns_domain, account_settings) and
// nothing else.
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{
Full: &proto.NetworkMapComponentsFull{
PeerConfig: in.PeerConfig,
DnsDomain: in.DNSDomain,
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
AccountSettings: &proto.AccountSettingsCompact{},
ProxyPatch: in.ProxyPatch,
},
},
}
}
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
// every regular peer (in c.Peers) must be indexed before any encoder
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
// peers that exist only in c.RouterPeers would silently lose their
// peer_index reference.
enc := newComponentEncoder(c)
enc.indexAllPeers()
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
// Phase 2: gather every policy that any consumer references (peer-pair
// policies + resource-only policies) so encodeResourcePoliciesMap can
// translate every *Policy pointer to a wire index.
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
policies, policyToIdxs := enc.encodePolicies(allPolicies)
// Phase 3: emit. Order of struct field expressions no longer matters:
// every encoder either reads from the dedup tables or works on
// independent input.
full := &proto.NetworkMapComponentsFull{
Serial: networkSerial(c.Network),
PeerConfig: in.PeerConfig,
Network: toAccountNetwork(c.Network),
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
ProxyPatch: in.ProxyPatch,
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
DnsDomain: in.DNSDomain,
CustomZoneDomain: c.CustomZoneDomain,
AgentVersions: enc.agentVersions,
Peers: enc.peers,
RouterPeerIndexes: routerIdxs,
Policies: policies,
Groups: enc.encodeGroups(),
Routes: enc.encodeRoutes(c.Routes),
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
AccountZones: encodeCustomZones(c.AccountZones),
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
}
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
}
}
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
// production path always populates c.Network (account_components.go:86), but
// the encoder is exported and a hand-built components struct may omit it.
func networkSerial(n *types.Network) uint64 {
if n == nil {
return 0
}
return n.CurrentSerial()
}
type componentEncoder struct {
components *types.NetworkMapComponents
peerOrder map[string]uint32
peers []*proto.PeerCompact
agentVersionOrder map[string]uint32
agentVersions []string
}
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
return &componentEncoder{
components: c,
peerOrder: make(map[string]uint32, len(c.Peers)),
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
agentVersionOrder: make(map[string]uint32),
}
}
func (e *componentEncoder) indexAllPeers() {
for _, p := range e.components.Peers {
if p == nil {
continue
}
e.appendPeer(p)
}
}
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
if idx, ok := e.peerOrder[p.ID]; ok {
return idx
}
idx := uint32(len(e.peers))
e.peerOrder[p.ID] = idx
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
return idx
}
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
if idx, ok := e.agentVersionOrder[v]; ok {
return idx
}
// Lazy-initialise the table with "" at index 0 so the empty string
// stays interchangeable with proto3's default uint32=0 — peers without
// a WtVersion don't force the table to materialise.
if v == "" {
idx := uint32(len(e.agentVersions))
if idx == 0 {
e.agentVersions = append(e.agentVersions, "")
}
e.agentVersionOrder[""] = idx
return idx
}
if len(e.agentVersions) == 0 {
e.agentVersions = append(e.agentVersions, "")
e.agentVersionOrder[""] = 0
}
idx := uint32(len(e.agentVersions))
e.agentVersionOrder[v] = idx
e.agentVersions = append(e.agentVersions, v)
return idx
}
// indexRouterPeers ensures every router peer is in the peer dedup table
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
// run before any encoder that resolves peer ids via e.peerOrder.
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
if len(routers) == 0 {
return nil
}
out := make([]uint32, 0, len(routers))
for _, p := range routers {
if p == nil {
continue
}
out = append(out, e.appendPeer(p))
}
return out
}
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
if len(e.components.Groups) == 0 {
return nil
}
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
for _, g := range e.components.Groups {
if !g.HasSeqID() {
continue
}
peerIdxs := make([]uint32, 0, len(g.Peers))
for _, peerID := range g.Peers {
if idx, ok := e.peerOrder[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
out = append(out, &proto.GroupCompact{
Id: g.AccountSeqID,
Name: g.Name,
PeerIndexes: peerIdxs,
})
}
return out
}
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
// list and a map from policy pointer to the indexes of its emitted rules in
// that list — used by encodeResourcePoliciesMap to translate
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
if len(policies) == 0 {
return nil, nil
}
out := make([]*proto.PolicyCompact, 0, len(policies))
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
for _, pol := range policies {
if !pol.HasSeqID() || !pol.Enabled {
continue
}
for _, r := range pol.Rules {
if r == nil || !r.Enabled {
continue
}
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
out = append(out, e.encodePolicyRule(pol, r))
}
}
return out, idxByPolicy
}
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
return &proto.PolicyCompact{
Id: pol.AccountSeqID,
Action: networkmap.GetProtoAction(string(r.Action)),
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
Bidirectional: r.Bidirectional,
Ports: portsToUint32(r.Ports),
PortRanges: portRangesToProto(r.PortRanges),
SourceGroupIds: e.groupSeqIDs(r.Sources),
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
AuthorizedUser: r.AuthorizedUser,
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
SourceResource: e.resourceToProto(r.SourceResource),
DestinationResource: e.resourceToProto(r.DestinationResource),
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
}
}
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
// dropping any group that has no seq id assigned.
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
if len(src) == 0 {
return nil
}
out := make([]uint32, 0, len(src))
for _, gid := range src {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
// unionPolicies merges c.Policies with every policy referenced by
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
// policies (relevant to a NetworkResource but not to peer-pair traffic)
// only live in ResourcePoliciesMap; without this union step they'd be lost
// from the wire and the client's resource-policy lookup would come back
// empty.
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
// Fast path: non-router peers have no resource-only policies, so the
// "union" is identical to `policies`. Skip the dedup map allocation.
if len(resourcePolicies) == 0 {
return policies
}
seen := make(map[*types.Policy]struct{}, len(policies))
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
for _, list := range resourcePolicies {
for _, p := range list {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
}
return out
}
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
// group xid → local-user names) to the wire form (map keyed by group
// account_seq_id → UserNameList). Groups without a seq id are dropped —
// matches how source/destination group references handle the same case.
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserNameList, len(m))
for groupID, names := range m {
seq, ok := e.groupSeq(groupID)
if !ok {
continue
}
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
}
return out
}
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
g, ok := e.components.Groups[groupID]
if !ok || !g.HasSeqID() {
return 0, false
}
return g.AccountSeqID, true
}
// resourceToProto translates types.Resource for the wire. For peer-typed
// resources the peer id is converted to a peer index into the envelope's
// peers array. For other resource types only the type string is shipped
// today (Calculate's resource-typed rule path consults SourceResource only
// for "peer" — other types fall through to group-based lookup).
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
if r.ID == "" && r.Type == "" {
return nil
}
out := &proto.ResourceCompact{Type: string(r.Type)}
if r.Type == types.ResourceTypePeer && r.ID != "" {
if idx, ok := e.peerOrder[r.ID]; ok {
out.PeerIndexSet = true
out.PeerIndex = idx
}
}
return out
}
// postureCheckSeqs translates a slice of posture-check xids to their
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
// lookup. Unresolvable xids are silently dropped — matches how group/peer
// references handle the same case.
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
return nil
}
out := make([]uint32, 0, len(xids))
for _, xid := range xids {
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
out = append(out, seq)
}
}
return out
}
// networkSeq translates a Network xid to its per-account integer id using
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
// the xid isn't known — callers decide whether to skip the parent record.
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
if xid == "" {
return 0, false
}
seq, ok := e.components.NetworkXIDToSeq[xid]
if !ok || seq == 0 {
return 0, false
}
return seq, true
}
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
if s == nil || len(s.DisabledManagementGroups) == 0 {
return nil
}
out := &proto.DNSSettingsCompact{
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
}
for _, gid := range s.DisabledManagementGroups {
if seq, ok := e.groupSeq(gid); ok {
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
}
}
return out
}
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
if len(routes) == 0 {
return nil
}
out := make([]*proto.RouteRaw, 0, len(routes))
for _, r := range routes {
if r == nil {
continue
}
rr := &proto.RouteRaw{
Id: r.AccountSeqID,
NetId: string(r.NetID),
Description: r.Description,
KeepRoute: r.KeepRoute,
NetworkType: int32(r.NetworkType),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
SkipAutoApply: r.SkipAutoApply,
Domains: r.Domains.ToPunycodeList(),
GroupIds: e.groupIDsToSeq(r.Groups),
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
}
if r.Network.IsValid() {
rr.NetworkCidr = r.Network.String()
}
if r.Peer != "" {
if idx, ok := e.peerOrder[r.Peer]; ok {
rr.PeerIndexSet = true
rr.PeerIndex = idx
}
}
out = append(out, rr)
}
return out
}
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
if len(groupIDs) == 0 {
return nil
}
out := make([]uint32, 0, len(groupIDs))
for _, gid := range groupIDs {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
if len(nsgs) == 0 {
return nil
}
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
for _, nsg := range nsgs {
if nsg == nil {
continue
}
entry := &proto.NameServerGroupRaw{
Id: nsg.AccountSeqID,
Name: nsg.Name,
Description: nsg.Description,
Nameservers: encodeNameServers(nsg.NameServers),
GroupIds: e.groupIDsToSeq(nsg.Groups),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
}
out = append(out, entry)
}
return out
}
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
if len(servers) == 0 {
return nil
}
out := make([]*proto.NameServer, 0, len(servers))
for _, s := range servers {
out = append(out, &proto.NameServer{
IP: s.IP.String(),
NSType: int64(s.NSType),
Port: int64(s.Port),
})
}
return out
}
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
if len(records) == 0 {
return nil
}
out := make([]*proto.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, &proto.SimpleRecord{
Name: r.Name,
Type: int64(r.Type),
Class: r.Class,
TTL: int64(r.TTL),
RData: r.RData,
})
}
return out
}
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
if len(zones) == 0 {
return nil
}
out := make([]*proto.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, &proto.CustomZone{
Domain: z.Domain,
Records: encodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
if len(resources) == 0 {
return nil
}
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
for _, r := range resources {
if r == nil {
continue
}
entry := &proto.NetworkResourceRaw{
Id: r.AccountSeqID,
Name: r.Name,
Description: r.Description,
Type: string(r.Type),
Address: r.Address,
DomainValue: r.Domain,
Enabled: r.Enabled,
}
if seq, ok := e.networkSeq(r.NetworkID); ok {
entry.NetworkSeq = seq
}
if r.Prefix.IsValid() {
entry.PrefixCidr = r.Prefix.String()
}
out = append(out, entry)
}
return out
}
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
if len(routersMap) == 0 {
return nil
}
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
for networkXID, routers := range routersMap {
if len(routers) == 0 {
continue
}
netSeq, ok := e.networkSeq(networkXID)
if !ok {
continue
}
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
for peerID, r := range routers {
if r == nil {
continue
}
entry := &proto.NetworkRouterEntry{
Id: r.AccountSeqID,
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
}
if idx, ok := e.peerOrder[peerID]; ok {
entry.PeerIndexSet = true
entry.PeerIndex = idx
}
entries = append(entries, entry)
}
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
}
return out
}
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
if len(rpm) == 0 {
return nil
}
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
// (small slice). Network resources without seq id are dropped, matching how
// other components-without-seq are silently filtered.
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
for _, r := range e.components.NetworkResources {
if r != nil && r.AccountSeqID != 0 {
resourceXIDToSeq[r.ID] = r.AccountSeqID
}
}
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
for resourceXID, policies := range rpm {
seq, ok := resourceXIDToSeq[resourceXID]
if !ok {
continue
}
idxs := make([]uint32, 0, len(policies)*2)
for _, pol := range policies {
idxs = append(idxs, policyToIdxs[pol]...)
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
}
return out
}
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserIDList, len(m))
for groupID, userIDs := range m {
seq, ok := e.groupSeq(groupID)
if !ok || len(userIDs) == 0 {
continue
}
out[seq] = &proto.UserIDList{UserIds: userIDs}
}
return out
}
func stringSetToSlice(s map[string]struct{}) []string {
if len(s) == 0 {
return nil
}
out := make([]string, 0, len(s))
for k := range s {
out = append(out, k)
}
return out
}
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.PeerIndexSet, len(m))
for checkXID, failedPeerIDs := range m {
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
if !ok || seq == 0 {
continue
}
idxs := make([]uint32, 0, len(failedPeerIDs))
for peerID := range failedPeerIDs {
if idx, ok := e.peerOrder[peerID]; ok {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
}
return out
}
// toAccountSettingsCompact always returns a non-nil message — the client
// dereferences it unconditionally during Calculate(), so a nil here would
// crash the receiver. A missing types.AccountSettingsInfo on the server
// (which shouldn't happen in production but the encoder is exported)
// degrades to login_expiration_enabled = false, which makes
// LoginExpired() return false for every peer.
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
if s == nil {
return &proto.AccountSettingsCompact{}
}
return &proto.AccountSettingsCompact{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
}
}
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
if n == nil {
return nil
}
out := &proto.AccountNetwork{
Identifier: n.Identifier,
NetCidr: n.Net.String(),
Dns: n.Dns,
Serial: n.CurrentSerial(),
}
if len(n.NetV6.IP) > 0 {
out.NetV6Cidr = n.NetV6.String()
}
return out
}
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
pc := &proto.PeerCompact{
WgPubKey: decodeWgKey(p.Key),
SshPubKey: []byte(p.SSHKey),
DnsLabel: p.DNSLabel,
AgentVersionIdx: agentVersionIdx,
AddedWithSsoLogin: p.UserID != "",
LoginExpirationEnabled: p.LoginExpirationEnabled,
SshEnabled: p.SSHEnabled,
SupportsIpv6: p.SupportsIPv6(),
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
}
if p.LastLogin != nil {
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
}
switch {
case !p.IP.IsValid():
// leave Ip nil
case p.IP.Is4() || p.IP.Is4In6():
ip := p.IP.Unmap().As4()
pc.Ip = ip[:]
default:
ip := p.IP.As16()
pc.Ip = ip[:]
}
if p.IPv6.IsValid() {
ip := p.IPv6.As16()
pc.Ipv6 = ip[:]
}
return pc
}
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
// key, or nil for an empty / malformed key.
func decodeWgKey(s string) []byte {
if s == "" {
return nil
}
out := make([]byte, wgKeyRawLen)
n, err := base64.StdEncoding.Decode(out, []byte(s))
if err != nil || n != wgKeyRawLen {
return nil
}
return out
}
func portsToUint32(ports []string) []uint32 {
if len(ports) == 0 {
return nil
}
out := make([]uint32, 0, len(ports))
for _, p := range ports {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
continue
}
out = append(out, uint32(v))
}
return out
}
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
if len(ranges) == 0 {
return nil
}
out := make([]*proto.PortInfo_Range, 0, len(ranges))
for _, r := range ranges {
out = append(out, &proto.PortInfo_Range{
Start: uint32(r.Start),
End: uint32(r.End),
})
}
return out
}

View File

@@ -0,0 +1,879 @@
package grpc
import (
"bytes"
"cmp"
"net"
"net/netip"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
// to reference the new peer indexes. Groups, policies, and router indexes are
// also sorted. After canonicalize, two envelopes built from the same logical
// input compare byte-equal via proto.Equal.
//
// This lives on the test side — the encoder itself emits in map-iteration
// order. Test-side normalization is the contract for "two encodes are
// equivalent".
func canonicalize(full *proto.NetworkMapComponentsFull) {
if full == nil {
return
}
// Canonicalize agent_versions first: sort the slice and rewrite each
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
// index 0 by convention.
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
if len(full.AgentVersions) > 0 {
// Pair version → original index, sort, rebuild.
type avEntry struct {
version string
oldIdx uint32
}
entries := make([]avEntry, len(full.AgentVersions))
for i, v := range full.AgentVersions {
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
}
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
// keeps the canonicalize output stable when two entries compare
// equal (the encoder dedups, but defending against future inputs).
slices.SortFunc(entries, func(a, b avEntry) int {
if a.version == "" && b.version != "" {
return -1
}
if b.version == "" && a.version != "" {
return 1
}
if c := cmp.Compare(a.version, b.version); c != 0 {
return c
}
return cmp.Compare(a.oldIdx, b.oldIdx)
})
newVersions := make([]string, len(entries))
for newIdx, e := range entries {
avRemap[e.oldIdx] = uint32(newIdx)
newVersions[newIdx] = e.version
}
full.AgentVersions = newVersions
}
for _, p := range full.Peers {
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
p.AgentVersionIdx = newIdx
}
}
type peerEntry struct {
peer *proto.PeerCompact
oldIdx uint32
}
entries := make([]peerEntry, len(full.Peers))
for i, p := range full.Peers {
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
}
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
// nil from malformed keys, or both empty for placeholders).
slices.SortFunc(entries, func(a, b peerEntry) int {
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
return c
}
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
})
remap := make(map[uint32]uint32, len(entries))
newPeers := make([]*proto.PeerCompact, len(entries))
for newIdx, e := range entries {
remap[e.oldIdx] = uint32(newIdx)
newPeers[newIdx] = e.peer
}
full.Peers = newPeers
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
for _, g := range full.Groups {
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
}
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
for _, r := range full.Routes {
if r.PeerIndexSet {
if newIdx, ok := remap[r.PeerIndex]; ok {
r.PeerIndex = newIdx
}
}
slices.Sort(r.GroupIds)
slices.Sort(r.AccessControlGroupIds)
slices.Sort(r.PeerGroupIds)
}
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
for _, list := range full.RoutersMap {
for _, entry := range list.Entries {
if entry.PeerIndexSet {
if newIdx, ok := remap[entry.PeerIndex]; ok {
entry.PeerIndex = newIdx
}
}
slices.Sort(entry.PeerGroupIds)
}
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
}
for _, set := range full.PostureFailedPeers {
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
}
for _, p := range full.Policies {
slices.Sort(p.SourceGroupIds)
slices.Sort(p.DestinationGroupIds)
}
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
// multiple PolicyCompact entries sharing the same Id (one per rule, when
// a Policy has multiple rules) still get a deterministic order. After
// sorting we remap indexes in ResourcePoliciesMap.
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
for i, p := range full.Policies {
policyOldOrder[p] = uint32(i)
}
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
if c := cmp.Compare(a.Id, b.Id); c != 0 {
return c
}
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
return c
}
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
})
policyRemap := make(map[uint32]uint32, len(full.Policies))
for newIdx, p := range full.Policies {
policyRemap[policyOldOrder[p]] = uint32(newIdx)
}
for _, idxs := range full.ResourcePoliciesMap {
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
}
for _, list := range full.GroupIdToUserIds {
slices.Sort(list.UserIds)
}
slices.Sort(full.AllowedUserIds)
}
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
out := make([]uint32, 0, len(idxs))
for _, i := range idxs {
if newIdx, ok := remap[i]; ok {
out = append(out, newIdx)
}
}
slices.Sort(out)
return out
}
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
// the encoder is intentionally non-deterministic.
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
canonicalize(a.GetFull())
canonicalize(b.GetFull())
return goproto.Equal(a, b)
}
func newTestComponents() *types.NetworkMapComponents {
peerA := &nbpeer.Peer{
ID: "peer-a",
Key: testWgKeyA,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peera",
SSHKey: "ssh-a",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-b",
Key: testWgKeyB,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
DNSLabel: "peerb",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
}
peerC := &nbpeer.Peer{
ID: "peer-c",
Key: testWgKeyC,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
return &types.NetworkMapComponents{
PeerID: "peer-a",
Network: &types.Network{
Identifier: "net-test",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 7,
},
AccountSettings: &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 2 * time.Hour,
},
Peers: map[string]*nbpeer.Peer{
"peer-a": peerA,
"peer-b": peerB,
"peer-c": peerC,
},
Groups: map[string]*types.Group{
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
},
Policies: []*types.Policy{
{
ID: "pol-1",
AccountSeqID: 10,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
Ports: []string{"22", "80"},
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
},
},
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
}
}
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
c := newTestComponents()
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full, "envelope must contain Full payload")
assert.EqualValues(t, 7, full.Serial)
assert.Equal(t, "netbird.cloud", full.DnsDomain)
require.NotNil(t, full.Network)
assert.Equal(t, "net-test", full.Network.Identifier)
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
require.NotNil(t, full.AccountSettings)
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
require.Len(t, full.Peers, 3)
byLabel := map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
byLabel[p.DnsLabel] = p
}
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
}
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
// Hammer it 100 times — Go map iteration is randomized per call, so each
// run produces different wire bytes, but the canonicalized form must
// match.
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d must be semantically equivalent to first encode", i)
}
}
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]*proto.NetworkMapEnvelope, goroutines)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
}()
}
wg.Wait()
for i, got := range results {
require.NotNil(t, got, "goroutine %d returned nil", i)
require.True(t, envelopesEquivalent(expected, got),
"goroutine %d produced inequivalent envelope", i)
}
}
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 2)
groupByID := map[uint32]*proto.GroupCompact{}
for _, g := range full.Groups {
groupByID[g.Id] = g
}
require.Contains(t, groupByID, uint32(1))
require.Contains(t, groupByID, uint32(2))
assert.Equal(t, "Src", groupByID[1].Name)
assert.Equal(t, "Dst", groupByID[2].Name)
assert.Len(t, groupByID[1].PeerIndexes, 1)
assert.Len(t, groupByID[2].PeerIndexes, 2)
}
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.EqualValues(t, 10, pc.Id)
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
assert.True(t, pc.Bidirectional)
assert.Equal(t, []uint32{22, 80}, pc.Ports)
require.Len(t, pc.PortRanges, 1)
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.RouterPeerIndexes, 1)
idx := full.RouterPeerIndexes[0]
require.Less(t, int(idx), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
}
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
"two distinct versions, order depends on map iteration")
idxByLabel := map[string]uint32{}
for _, p := range full.Peers {
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
}
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
}
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
c := newTestComponents()
c.Policies[0].Enabled = false
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
c := newTestComponents()
c.Groups["group-src"].AccountSeqID = 0
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
assert.EqualValues(t, 2, full.Groups[0].Id)
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
// Both peers have nil WgPubKey after decode; canonicalize must still
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
// canonicalize identically.
c := newTestComponents()
c.Peers["peer-a"].Key = "garbage-a-!!!"
c.Peers["peer-b"].Key = "garbage-b-!!!"
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d with two same-key peers must canonicalize equivalently", i)
}
}
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
c := newTestComponents()
c.Peers["peer-a"].Key = "not-base64-!!!"
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Peers, 3)
var byLabel = map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
byLabel[p.DnsLabel] = p
}
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
}
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
c := newTestComponents()
v6Only := &nbpeer.Peer{
ID: "peer-v6",
Key: testWgKeyA,
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
DNSLabel: "peerv6",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.Peers["peer-v6"] = v6Only
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerv6" {
found = p
}
}
require.NotNil(t, found, "ipv6-only peer must be present")
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
assert.Len(t, found.Ipv6, 16)
}
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
c := newTestComponents()
c.Peers["peer-noip"] = &nbpeer.Peer{
ID: "peer-noip",
Key: testWgKeyA,
DNSLabel: "peernoip",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peernoip" {
found = p
}
}
require.NotNil(t, found)
assert.Empty(t, found.Ip)
assert.Empty(t, found.Ipv6)
}
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
}
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
full := env.GetFull()
require.NotNil(t, full)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Groups)
assert.Empty(t, full.Policies)
assert.Empty(t, full.RouterPeerIndexes)
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
}
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
c := newTestComponents()
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
c.Peers["peer-a"].UserID = "user-1"
c.Peers["peer-a"].LoginExpirationEnabled = true
c.Peers["peer-a"].LastLogin = &now
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var pa *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peera" {
pa = p
}
}
require.NotNil(t, pa)
assert.True(t, pa.AddedWithSsoLogin)
assert.True(t, pa.LoginExpirationEnabled)
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
// peer-b has no UserID and no LastLogin → all fields zero-value.
var pb *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerb" {
pb = p
}
}
require.NotNil(t, pb)
assert.False(t, pb.AddedWithSsoLogin)
assert.False(t, pb.LoginExpirationEnabled)
assert.Zero(t, pb.LastLoginUnixNano)
}
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{
{
ID: "route-peer",
AccountSeqID: 100,
NetID: "net-A",
Description: "via peer-c",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Peer: "peer-c", // peer ID, not WG key
Groups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
Enabled: true,
},
{
ID: "route-peergroup",
AccountSeqID: 101,
NetID: "net-B",
Network: netip.MustParsePrefix("10.1.0.0/16"),
PeerGroups: []string{"group-src", "group-dst"},
Enabled: true,
},
{
ID: "route-no-seq",
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
Network: netip.MustParsePrefix("10.2.0.0/16"),
Enabled: true,
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 3)
byNetID := map[string]*proto.RouteRaw{}
for _, r := range full.Routes {
byNetID[r.NetId] = r
}
r1 := byNetID["net-A"]
require.NotNil(t, r1)
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
require.Less(t, int(r1.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
assert.Empty(t, r1.PeerGroupIds)
r2 := byNetID["net-B"]
require.NotNil(t, r2)
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{{
ID: "route-x",
AccountSeqID: 100,
Peer: "peer-not-in-components",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Enabled: true,
}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 1)
assert.False(t, full.Routes[0].PeerIndexSet,
"missing peer reference must not pretend to point at peer index 0")
}
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
c := newTestComponents()
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
// is the I1 case — without unionPolicies the encoder would silently
// drop it from the wire.
resourceOnlyPolicy := &types.Policy{
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
}
c.ResourcePoliciesMap = map[string][]*types.Policy{
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
}
// Resource must appear in components.NetworkResources with a seq id —
// encoder uses that to translate the xid map key to uint32.
c.NetworkResources = []*resourceTypes.NetworkResource{
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
policyByID := map[uint32]*proto.PolicyCompact{}
policyIdxByID := map[uint32]uint32{}
for i, p := range full.Policies {
policyByID[p.Id] = p
policyIdxByID[p.Id] = uint32(i)
}
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
idxs := full.ResourcePoliciesMap[77].Indexes
require.Len(t, idxs, 2)
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
"resource policies map must reference both wire policy indexes")
}
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
c := newTestComponents()
c.NameServerGroups = []*nbdns.NameServerGroup{{
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
}},
Groups: []string{"group-src", "group-not-persisted"},
Primary: true, Enabled: true,
Domains: []string{"corp.example"},
}}
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.NameserverGroups, 1)
nsg := full.NameserverGroups[0]
assert.EqualValues(t, 50, nsg.Id)
assert.Equal(t, "Main", nsg.Name)
assert.True(t, nsg.Primary)
require.Len(t, nsg.Nameservers, 1)
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
}
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
c := newTestComponents()
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
c.PostureFailedPeers = map[string]map[string]struct{}{
"check-1": {
"peer-a": {},
"peer-b": {},
"peer-not-in-account": {},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.PostureFailedPeers, uint32(33))
idxs := full.PostureFailedPeers[33].PeerIndexes
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
}
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
c := newTestComponents()
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"peer-c": {
ID: "router-1", AccountSeqID: 200,
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
entries := full.RoutersMap[5].Entries
require.Len(t, entries, 1)
e := entries[0]
assert.EqualValues(t, 200, e.Id)
assert.True(t, e.PeerIndexSet)
require.Less(t, int(e.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
assert.True(t, e.Masquerade)
assert.EqualValues(t, 10, e.Metric)
assert.True(t, e.Enabled)
}
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
// peer_index reference must still resolve.
c := newTestComponents()
delete(c.Peers, "peer-c")
routerPeer := &nbpeer.Peer{
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
require.Len(t, full.RoutersMap[5].Entries, 1)
e := full.RoutersMap[5].Entries[0]
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
}
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
c := newTestComponents()
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.DnsSettings)
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
}
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
c := newTestComponents()
c.GroupIDToUserIDs = map[string][]string{
"group-src": {"user-1", "user-2"},
"group-no-seq": {"user-3"}, // group not persisted → drop
"group-missing": {"user-4"}, // group not in components → drop
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
require.Contains(t, full.GroupIdToUserIds, uint32(1))
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
}
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
}
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
nm := &types.NetworkMap{
Peers: []*nbpeer.Peer{{
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}},
FirewallRules: []*types.FirewallRule{{
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
}},
}
patch := toProxyPatch(nm, "netbird.cloud", false, false)
require.NotNil(t, patch)
assert.Len(t, patch.Peers, 1)
assert.Len(t, patch.FirewallRules, 1)
}
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
// pass-through in both encoder branches (normal path + nil-Components
// graceful-degrade). Without this test a regression that drops `ProxyPatch:`
// from one of the struct literals in components_encoder.go would slip past CI.
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
patch := &proto.ProxyPatch{
ForwardingRules: []*proto.ForwardingRule{{
Protocol: proto.RuleProtocol_TCP,
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
}},
}
t.Run("normal_path", func(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
}
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
// nil Components → minimal envelope, no crash. Matches the legacy
// account_components.go:43 behaviour for missing/unvalidated peers.
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full)
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
assert.Equal(t, "netbird.cloud", full.DnsDomain)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
// AccountSettings deliberately nil
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
}

View File

@@ -0,0 +1,193 @@
package grpc
import (
"context"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ToComponentSyncResponse builds a SyncResponse carrying the compact
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
// field is intentionally left empty — capable peers ignore it and the
// envelope alone is the authoritative wire shape.
//
// PeerConfig is computed once server-side using the receiving peer's own
// account-level network metadata. EnableSSH inside PeerConfig is left at
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
// client even though the server-side PeerConfig reports false.
func ToComponentSyncResponse(
ctx context.Context,
config *nbconfig.Config,
httpConfig *nbconfig.HttpServerConfig,
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
peer *nbpeer.Peer,
turnCredentials *Token,
relayCredentials *Token,
components *types.NetworkMapComponents,
proxyPatch *types.NetworkMap,
dnsName string,
checks []*posture.Checks,
settings *types.Settings,
extraSettings *types.ExtraSettings,
peerGroups []string,
dnsFwdPort int64,
) *proto.SyncResponse {
network := networkOrZero(components)
enableSSH := computeSSHEnabledForPeer(components, peer)
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
useSourcePrefixes := peer.SupportsSourcePrefixes()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: components,
PeerConfig: peerConfig,
DNSDomain: dnsName,
DNSForwarderPort: dnsFwdPort,
UserIDClaim: userIDClaim,
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
})
resp := &proto.SyncResponse{
PeerConfig: peerConfig,
NetworkMapEnvelope: envelope,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
return resp
}
// networkOrZero returns components.Network or a zero Network — toPeerConfig
// dereferences network.Net which would panic on nil.
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
if c == nil || c.Network == nil {
return &types.Network{}
}
return c.Network
}
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
// patch the components envelope ships alongside. Returns nil when there are
// no fragments to merge — proto3 omits a nil message field, so the receiver
// sees no patch and skips the merge step entirely.
//
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
// delivers fragments pre-expanded — there's no raw component shape to
// derive them from. Components purity isn't violated: proxy data isn't
// policy-graph-derived, it's externally injected post-Calculate, so the
// client merges it on top of its locally-computed NetworkMap.
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
if nm == nil {
return nil
}
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
return nil
}
patch := &proto.ProxyPatch{
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
Routes: networkmap.ToProtocolRoutes(nm.Routes),
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
}
if len(nm.ForwardingRules) > 0 {
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
for _, r := range nm.ForwardingRules {
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
}
}
return patch
}
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
// without this helper the field would be incorrectly false for any peer
// that's the destination of an SSH-enabling policy without having
// peer.SSHEnabled set locally.
//
// Mirrors the two activation paths in Calculate() (`networkmap_components.go`
// `getPeerConnectionResources`):
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
// destinations.
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
// destinations, AND the peer has SSHEnabled set locally — this is the
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
//
// The full SSH AuthorizedUsers map is still produced by the client when it
// runs Calculate() over the envelope.
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
if c == nil || peer == nil {
return false
}
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
// exist in c.Peers, otherwise no rule applies to it.
if _, ok := c.Peers[peer.ID]; !ok {
return false
}
for _, policy := range c.Policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if ruleEnablesSSHForPeer(c, rule, peer) {
return true
}
}
}
return false
}
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
// peer itself has SSH enabled locally.
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
if rule == nil || !rule.Enabled {
return false
}
if !peerInDestinations(c, rule, peer.ID) {
return false
}
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
return true
}
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
}
// peerInDestinations reports whether peerID is in any of rule.Destinations'
// groups (or matches DestinationResource if it's a peer-typed resource —
// for non-peer types Calculate falls through to group lookup, so we mirror
// that exactly to avoid silent divergence).
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if c.IsPeerInGroup(peerID, groupID) {
return true
}
}
return false
}

View File

@@ -0,0 +1,186 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for
// the B1 fix that the prod-DB equivalence test alone wouldn't have caught
// if no account had this combination.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -7,23 +7,18 @@ import (
"net/url"
"strings"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
@@ -138,8 +133,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
},
Checks: toProtocolChecks(ctx, checks),
@@ -152,19 +147,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
@@ -177,7 +172,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
@@ -188,78 +183,6 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
return response
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
@@ -277,204 +200,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
}
}
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
// alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
// when includeIPv6 is true.
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
// IPv4Unspecified/0 is always valid, error is impossible.
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
// IPv6Unspecified/0 is always valid, error is impossible.
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if shouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
if config == nil || config.AuthAudience == "" {

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/shared/management/networkmap"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -61,13 +62,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
@@ -99,7 +100,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
@@ -107,7 +108,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}

View File

@@ -932,7 +932,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
dnsName := s.networkMapController.GetDNSDomain(settings)
var plainResp *proto.SyncResponse
if s.networkMapController.PeerNeedsComponents(peer) {
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
// computed and recompute the raw components instead. This wastes one
// Calculate() call per initial-sync — the component-based wire
// format is what the peer actually consumes. The streaming path
// (network_map.Controller.UpdateAccountPeers) skips this duplication
// because it dispatches by capability before computing.
//
// TODO(step-4-sync): refactor SyncPeer / SyncAndMarkPeer / their
// mocks + manager interfaces to return PeerNetworkMapResult so the
// initial-sync path stops doing duplicate work. ~13 files of churn,
// deferred until the client-side decoder lands and there's a real
// deployment of capability=3 peers worth optimizing for.
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
}
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
key, err := s.secretsManager.GetWGKey()
if err != nil {

View File

@@ -1621,6 +1621,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
for _, g := range newGroupsToCreate {
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
if err != nil {
return fmt.Errorf("error allocating group seq id: %w", err)
}
g.AccountSeqID = seq
}
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}

View File

@@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
var newJWTGroup *types.Group
for _, g := range groups {
if g.Name == "group3" {
newJWTGroup = g
break
}
}
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {

View File

@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err
}
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
}
newGroup.AccountSeqID = seq
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
newGroup.AccountID = accountID
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
if err != nil {
return err
}
newGroup.AccountSeqID = seq
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err
}
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
newGroup.AccountID = accountID
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return err
}
newGroup.AccountSeqID = oldGroup.AccountSeqID
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}

View File

@@ -0,0 +1,156 @@
package migration
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/types"
)
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
// with the next free id per account. Idempotent: safe to re-run; both steps
// no-op once everything is consistent.
//
// Implemented as two table-wide SQL statements with window functions, one
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
// well under a second instead of the per-account-loop ~2 minutes.
//
// orderColumn is the column to use when assigning the deterministic ordering
// (typically the primary-key string id).
func BackfillAccountSeqIDs[T any](
ctx context.Context,
db *gorm.DB,
entity types.AccountSeqEntity,
orderColumn string,
) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
return nil
}
stmt := &gorm.Statement{DB: db}
if err := stmt.Parse(&model); err != nil {
return fmt.Errorf("parse model: %w", err)
}
table := quoteIdent(db, stmt.Schema.Table)
orderCol := quoteIdent(db, orderColumn)
return db.Transaction(func(tx *gorm.DB) error {
var pending int64
if err := tx.Raw(
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
).Scan(&pending).Error; err != nil {
return fmt.Errorf("count pending on %s: %w", table, err)
}
if pending > 0 {
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
if err := backfillRankSQL(tx, table, orderCol); err != nil {
return fmt.Errorf("rank %s: %w", table, err)
}
}
if err := seedCountersSQL(tx, table, entity); err != nil {
return fmt.Errorf("seed counters for %s: %w", entity, err)
}
return nil
})
}
func quoteIdent(db *gorm.DB, name string) string {
switch db.Dialector.Name() {
case "mysql":
return "`" + name + "`"
case "postgres":
return `"` + name + `"`
default:
return name
}
}
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres", "sqlite":
sql = fmt.Sprintf(`
WITH max_seq AS (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
),
ranked AS (
SELECT p.id,
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
FROM %s p
JOIN max_seq m ON p.account_id = m.account_id
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
)
UPDATE %s SET account_seq_id = ranked.new_seq
FROM ranked
WHERE %s.id = ranked.id
`, table, orderCol, table, table, table)
case "mysql":
sql = fmt.Sprintf(`
UPDATE %s p
JOIN (
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
FROM %s
GROUP BY account_id
) m ON p.account_id = m.account_id
JOIN (
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
FROM %s
WHERE account_seq_id IS NULL OR account_seq_id = 0
) r ON p.id = r.id
SET p.account_seq_id = m.max_seq + r.rn
`, table, table, orderCol, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql).Error
}
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
dialect := db.Dialector.Name()
var sql string
switch dialect {
case "postgres":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
`, table)
case "sqlite":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
`, table)
case "mysql":
sql = fmt.Sprintf(`
INSERT INTO account_seq_counters (account_id, entity, next_id)
SELECT account_id, ?, MAX(account_seq_id) + 1
FROM %s
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
GROUP BY account_id
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
`, table)
default:
return fmt.Errorf("unsupported dialect: %s", dialect)
}
return db.Exec(sql, string(entity)).Error
}

View File

@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err
}
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
if err != nil {
return err
}
newNSGroup.AccountSeqID = seq
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}

View File

@@ -71,9 +71,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
network.ID = xid.New().String()
err = m.store.SaveNetwork(ctx, network)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
if err != nil {
return fmt.Errorf("failed to allocate network seq id: %w", err)
}
network.AccountSeqID = seq
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err)
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
@@ -102,14 +113,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return nil, status.NewPermissionDeniedError()
}
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
network.AccountSeqID = existing.AccountSeqID
if err := transaction.SaveNetwork(ctx, network); err != nil {
return fmt.Errorf("failed to save network: %w", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, network)
return network, nil
}
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {

View File

@@ -252,3 +252,73 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
require.Error(t, err)
require.Nil(t, updatedNetwork)
}
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
// non-zero AccountSeqID on the persisted network (allocated through the
// account_seq_counters table).
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
ctx := context.Background()
const accountID = "testAccountId"
const userID = "testAdminId"
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
AccountID: accountID,
Name: "seq-allocation-test",
})
require.NoError(t, err)
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
}
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
// AccountSeqID even when the caller passes a zero value (the shape REST
// handlers produce because the field is `json:"-"`).
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
ctx := context.Background()
const accountID = "testAccountId"
const userID = "testAdminId"
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
AccountID: accountID,
Name: "seq-preserve-original",
})
require.NoError(t, err)
originalSeq := created.AccountSeqID
require.NotZero(t, originalSeq)
update := &types.Network{
AccountID: accountID,
ID: created.ID,
Name: "seq-preserve-renamed",
}
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
_, err = manager.UpdateNetwork(ctx, userID, update)
require.NoError(t, err)
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
require.Equal(t, "seq-preserve-renamed", got.Name)
}

View File

@@ -125,6 +125,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network: %w", err)
}
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
if err != nil {
return fmt.Errorf("failed to allocate network resource seq id: %w", err)
}
resource.AccountSeqID = seq
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
@@ -231,6 +237,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
resource.AccountSeqID = oldResource.AccountSeqID
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {

View File

@@ -32,6 +32,9 @@ type NetworkResource struct {
ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
Name string
Description string
Type NetworkResourceType
@@ -93,17 +96,18 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
func (n *NetworkResource) Copy() *NetworkResource {
return &NetworkResource{
ID: n.ID,
AccountID: n.AccountID,
NetworkID: n.NetworkID,
Name: n.Name,
Description: n.Description,
Type: n.Type,
Address: n.Address,
Domain: n.Domain,
Prefix: n.Prefix,
GroupIDs: n.GroupIDs,
Enabled: n.Enabled,
ID: n.ID,
AccountID: n.AccountID,
NetworkID: n.NetworkID,
AccountSeqID: n.AccountSeqID,
Name: n.Name,
Description: n.Description,
Type: n.Type,
Address: n.Address,
Domain: n.Domain,
Prefix: n.Prefix,
GroupIDs: n.GroupIDs,
Enabled: n.Enabled,
}
}

View File

@@ -102,6 +102,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
router.ID = xid.New().String()
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
if err != nil {
return fmt.Errorf("failed to allocate network router seq id: %w", err)
}
router.AccountSeqID = seq
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to create network router: %w", err)
@@ -166,6 +172,22 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
if err == nil {
router.AccountSeqID = oldRouter.AccountSeqID
} else if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
// PUT-as-upsert: caller may target a brand-new router id (used by
// the dashboard's "save" flow). Allocate a fresh account_seq_id so
// the upsert behaves the same as Create().
seq, allocErr := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
if allocErr != nil {
return fmt.Errorf("failed to allocate network router seq id: %w", allocErr)
}
router.AccountSeqID = seq
} else {
return fmt.Errorf("failed to get existing network router: %w", err)
}
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)

View File

@@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
if err != nil {
require.NoError(t, err)
}
router.ID = "testRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {

View File

@@ -13,6 +13,9 @@ type NetworkRouter struct {
ID string `gorm:"primaryKey"`
NetworkID string `gorm:"index"`
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
Peer string
PeerGroups []string `gorm:"serializer:json"`
Masquerade bool
@@ -78,14 +81,15 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
func (n *NetworkRouter) Copy() *NetworkRouter {
return &NetworkRouter{
ID: n.ID,
NetworkID: n.NetworkID,
AccountID: n.AccountID,
Peer: n.Peer,
PeerGroups: n.PeerGroups,
Masquerade: n.Masquerade,
Metric: n.Metric,
Enabled: n.Enabled,
ID: n.ID,
NetworkID: n.NetworkID,
AccountID: n.AccountID,
AccountSeqID: n.AccountSeqID,
Peer: n.Peer,
PeerGroups: n.PeerGroups,
Masquerade: n.Masquerade,
Metric: n.Metric,
Enabled: n.Enabled,
}
}

View File

@@ -7,12 +7,24 @@ import (
)
type Network struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
Name string
Description string
}
// HasSeqID reports whether the network has been persisted long enough to have
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
// must skip networks that return false here.
func (n *Network) HasSeqID() bool {
return n != nil && n.AccountSeqID != 0
}
func NewNetwork(accountId, name, description string) *Network {
return &Network{
ID: xid.New().String(),
@@ -41,13 +53,14 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
}
}
// Copy returns a copy of a posture checks.
// Copy returns a copy of a network.
func (n *Network) Copy() *Network {
return &Network{
ID: n.ID,
AccountID: n.AccountID,
Name: n.Name,
Description: n.Description,
ID: n.ID,
AccountID: n.AccountID,
AccountSeqID: n.AccountSeqID,
Name: n.Name,
Description: n.Description,
}
}

View File

@@ -13,8 +13,9 @@ import (
// Peer capability constants mirror the proto enum values.
const (
PeerCapabilitySourcePrefixes int32 = 1
PeerCapabilityIPv6Overlay int32 = 2
PeerCapabilitySourcePrefixes int32 = 1
PeerCapabilityIPv6Overlay int32 = 2
PeerCapabilityComponentNetworkMap int32 = 3
)
// Peer represents a machine connected to the network.
@@ -247,6 +248,14 @@ func (p *Peer) SupportsSourcePrefixes() bool {
return p.HasCapability(PeerCapabilitySourcePrefixes)
}
// SupportsComponentNetworkMap reports whether the peer assembles its
// NetworkMap from server-shipped components instead of consuming a fully
// expanded NetworkMap. Determines whether the network_map controller skips
// Calculate() server-side and emits the components envelope.
func (p *Peer) SupportsComponentNetworkMap() bool {
return p.HasCapability(PeerCapabilityComponentNetworkMap)
}
func capabilitiesEqual(a, b []int32) bool {
if len(a) != len(b) {
return false

View File

@@ -69,6 +69,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
policy.AccountSeqID = existingPolicy.AccountSeqID
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
@@ -78,6 +80,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
policy.AccountSeqID = seq
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}

View File

@@ -47,10 +47,21 @@ type Checks struct {
// AccountID is a reference to the Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
// Checks is a set of objects that perform the actual checks
Checks ChecksDefinition `gorm:"serializer:json"`
}
// HasSeqID reports whether the posture check has been persisted long enough
// to have a per-account sequence id allocated. Wire encoders that key off
// AccountSeqID must skip checks that return false here.
func (pc *Checks) HasSeqID() bool {
return pc != nil && pc.AccountSeqID != 0
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`
@@ -121,11 +132,12 @@ func (*Checks) TableName() string {
// Copy returns a copy of a posture checks.
func (pc *Checks) Copy() *Checks {
checks := &Checks{
ID: pc.ID,
Name: pc.Name,
Description: pc.Description,
AccountID: pc.AccountID,
Checks: pc.Checks.Copy(),
ID: pc.ID,
Name: pc.Name,
Description: pc.Description,
AccountID: pc.AccountID,
AccountSeqID: pc.AccountSeqID,
Checks: pc.Checks.Copy(),
}
return checks
}

View File

@@ -51,12 +51,24 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if isUpdate {
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
if err != nil {
return err
}
postureChecks.AccountSeqID = existing.AccountSeqID
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
action = activity.PostureCheckUpdated
} else {
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
if err != nil {
return err
}
postureChecks.AccountSeqID = seq
}
postureChecks.AccountID = accountID

View File

@@ -563,3 +563,61 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
assert.False(t, result)
})
}
// TestSavePostureChecks_AllocatesSeqIDOnCreate verifies that the create path
// (no incoming ID) allocates a non-zero AccountSeqID via the
// account_seq_counters table.
func TestSavePostureChecks_AllocatesSeqIDOnCreate(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
account, err := initTestPostureChecksAccount(am)
require.NoError(t, err)
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: "seq-allocation-test",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
},
}, true)
require.NoError(t, err)
require.NotZero(t, created.AccountSeqID, "SavePostureChecks on create must allocate a non-zero AccountSeqID")
}
// TestSavePostureChecks_PreservesSeqIDOnUpdate verifies the update path does
// not reset AccountSeqID even when the caller passes a zero value (REST
// handler shape, because the field is `json:"-"`).
func TestSavePostureChecks_PreservesSeqIDOnUpdate(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
account, err := initTestPostureChecksAccount(am)
require.NoError(t, err)
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: "seq-preserve-original",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
},
}, true)
require.NoError(t, err)
originalSeq := created.AccountSeqID
require.NotZero(t, originalSeq)
update := &posture.Checks{
ID: created.ID,
Name: "seq-preserve-renamed",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.27.0"},
},
}
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, update, false)
require.NoError(t, err)
got, err := am.GetPostureChecks(context.Background(), account.Id, created.ID, adminUserID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive SavePostureChecks update")
require.Equal(t, "seq-preserve-renamed", got.Name)
}

View File

@@ -178,6 +178,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
if err != nil {
return err
}
newRoute.AccountSeqID = seq
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err
}
@@ -231,6 +237,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
routeToSave.AccountID = accountID
routeToSave.AccountSeqID = oldRoute.AccountSeqID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err

View File

@@ -0,0 +1,506 @@
package store
import (
"context"
"errors"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
var errRollback = errors.New("intentional rollback")
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accA = "acc-a"
const accB = "acc-b"
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(1), got)
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(2), got)
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(1), got, "different account starts from 1")
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
require.NoError(t, err)
require.Equal(t, uint32(1), got, "different entity starts from 1")
return nil
}))
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, uint32(3), got, "counter persists across transactions")
return nil
}))
}
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotEmpty(t, policies, "test fixture must have policies")
seen := make(map[uint32]bool)
for _, p := range policies {
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
seen[p.AccountSeqID] = true
}
}
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
const policyID = "cs1tnh0hhcjnqoiuebf0"
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
require.NoError(t, err)
originalSeq := original.AccountSeqID
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
updated := &types.Policy{
ID: policyID,
AccountID: accountID,
Name: "renamed",
Enabled: false,
Rules: original.Rules,
}
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
require.NoError(t, store.SavePolicy(ctx, updated))
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
require.Equal(t, "renamed", got.Name)
}
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotEmpty(t, groups)
original := groups[0]
originalSeq := original.AccountSeqID
require.NotZero(t, originalSeq)
updated := &types.Group{
ID: original.ID,
AccountID: accountID,
Name: "renamed",
Issued: original.Issued,
}
require.Zero(t, updated.AccountSeqID)
require.NoError(t, store.UpdateGroup(ctx, updated))
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
require.NoError(t, err)
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
require.Equal(t, "renamed", got.Name)
}
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "save-account-seqid-test"
account := &types.Account{
Id: accountID,
CreatedBy: "user1",
Domain: "example.test",
DNSSettings: types.DNSSettings{},
Settings: &types.Settings{},
Network: &types.Network{
Identifier: "net-test",
},
Users: map[string]*types.User{
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
},
}
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
require.Len(t, account.Groups, 1, "default 'All' group must be present")
require.Len(t, account.Policies, 1, "default policy must be present")
for _, g := range account.Groups {
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
}
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
require.NoError(t, store.SaveAccount(ctx, account))
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.Len(t, groups, 1)
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
require.Len(t, policies, 1)
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
require.NoError(t, err)
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
require.NoError(t, err)
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
return errRollback
}), errRollback)
}
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
groupSeqs := make(map[string]uint32)
policySeqs := make(map[string]uint32)
routeSeqs := make(map[route.ID]uint32)
nsgSeqs := make(map[string]uint32)
resourceSeqs := make(map[string]uint32)
routerSeqs := make(map[string]uint32)
networkSeqs := make(map[string]uint32)
for _, g := range account.Groups {
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
groupSeqs[g.ID] = g.AccountSeqID
}
for _, p := range account.Policies {
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
policySeqs[p.ID] = p.AccountSeqID
}
for _, r := range account.Routes {
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
routeSeqs[r.ID] = r.AccountSeqID
}
for _, n := range account.NameServerGroups {
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
nsgSeqs[n.ID] = n.AccountSeqID
}
for _, nr := range account.NetworkResources {
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
resourceSeqs[nr.ID] = nr.AccountSeqID
}
for _, nr := range account.NetworkRouters {
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
routerSeqs[nr.ID] = nr.AccountSeqID
}
for _, n := range account.Networks {
require.NotZero(t, n.AccountSeqID, "fixture network must have seq>0 after backfill")
networkSeqs[n.ID] = n.AccountSeqID
}
require.NoError(t, store.SaveAccount(ctx, account))
after, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
for _, g := range after.Groups {
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
}
for _, p := range after.Policies {
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
}
for _, r := range after.Routes {
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
}
for _, n := range after.NameServerGroups {
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
}
for _, nr := range after.NetworkResources {
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
}
for _, nr := range after.NetworkRouters {
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
}
for _, n := range after.Networks {
require.Equal(t, networkSeqs[n.ID], n.AccountSeqID, "network %s seq must be preserved", n.ID)
}
}
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "save-account-all-entities"
addr, err := netip.ParseAddr("8.8.8.8")
require.NoError(t, err)
account := &types.Account{
Id: accountID,
CreatedBy: "user1",
Domain: "example.test",
Settings: &types.Settings{},
Network: &types.Network{Identifier: "net-test"},
Users: map[string]*types.User{
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
},
Groups: map[string]*types.Group{
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
},
Policies: []*types.Policy{
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
},
Routes: map[route.ID]*route.Route{
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
},
Networks: []*networkTypes.Network{
{ID: "n1", AccountID: accountID, Name: "n1"},
},
PostureChecks: []*posture.Checks{
{ID: "pc1", AccountID: accountID, Name: "pc1",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
}
require.NoError(t, store.SaveAccount(ctx, account))
after, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, after.Groups, 1)
require.Len(t, after.Policies, 1)
require.Len(t, after.Routes, 1)
require.Len(t, after.NameServerGroups, 1)
require.Len(t, after.NetworkResources, 1)
require.Len(t, after.NetworkRouters, 1)
require.Len(t, after.Networks, 1)
require.Len(t, after.PostureChecks, 1)
for _, g := range after.Groups {
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
}
for _, p := range after.Policies {
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
}
for _, r := range after.Routes {
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
}
for _, n := range after.NameServerGroups {
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
}
for _, nr := range after.NetworkResources {
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
}
for _, nr := range after.NetworkRouters {
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
}
for _, n := range after.Networks {
require.NotZero(t, n.AccountSeqID, "network seq must be allocated")
}
for _, pc := range after.PostureChecks {
require.NotZero(t, pc.AccountSeqID, "posture_check seq must be allocated")
}
require.NoError(t, store.SaveAccount(ctx, after))
final, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
for _, r := range final.Routes {
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
}
for _, n := range final.NameServerGroups {
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
}
afterByID := map[string]uint32{}
for _, n := range after.Networks {
afterByID[n.ID] = n.AccountSeqID
}
for _, n := range final.Networks {
require.Equal(t, afterByID[n.ID], n.AccountSeqID, "network seq preserved on re-save")
}
afterPCByID := map[string]uint32{}
for _, pc := range after.PostureChecks {
afterPCByID[pc.ID] = pc.AccountSeqID
}
for _, pc := range final.PostureChecks {
require.Equal(t, afterPCByID[pc.ID], pc.AccountSeqID, "posture_check seq preserved on re-save")
}
}
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "concurrent-test"
const entity = types.AccountSeqEntityPolicy
const goroutines = 32
type result struct {
seq uint32
err error
}
results := make(chan result, goroutines)
start := make(chan struct{})
for i := 0; i < goroutines; i++ {
go func() {
<-start
var allocated uint32
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
allocated = seq
return err
})
results <- result{seq: allocated, err: err}
}()
}
close(start)
seen := make(map[uint32]int, goroutines)
for i := 0; i < goroutines; i++ {
r := <-results
require.NoError(t, r.err, "concurrent allocate must not fail")
require.NotZero(t, r.seq, "allocated seq must be non-zero")
seen[r.seq]++
}
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
for i := uint32(1); i <= goroutines; i++ {
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
}
}
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*types.Group{
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
}
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
for _, want := range groups {
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
require.NoError(t, err)
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
}
groups[0].Name = "g1-renamed"
groups[0].AccountSeqID = 0
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
require.NoError(t, err)
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
}
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
maxSeq := uint32(0)
for _, p := range existing {
if p.AccountSeqID > maxSeq {
maxSeq = p.AccountSeqID
}
}
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
newPolicy := &types.Policy{
ID: "bench-new-policy",
AccountID: accountID,
AccountSeqID: seq,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "bench-new-policy-rule",
PolicyID: "bench-new-policy",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Sources: []string{"groupA"},
Destinations: []string{"groupC"},
Bidirectional: true,
}},
}
return tx.CreatePolicy(ctx, newPolicy)
}))
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
require.NoError(t, err)
require.Equal(t, maxSeq+1, created.AccountSeqID)
}

View File

@@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
&types.AccountSeqCounter{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -307,6 +308,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
return result.Error
}
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
return fmt.Errorf("assign seq ids: %w", err)
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
@@ -658,6 +663,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
}
// CreateGroups creates the given list of groups to the database.
// groupUpsertColumns is the explicit allowlist of columns that get updated when
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
// handler-built struct) cannot reset the persisted seq id during an upsert.
// Keep this in sync with the Group schema in management/server/types/group.go.
func groupUpsertColumns() clause.Set {
return clause.AssignmentColumns([]string{
"account_id",
"name",
"issued",
"integration_ref_id",
"integration_ref_integration_type",
"resources",
})
}
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
if len(groups) == 0 {
return nil
@@ -667,8 +688,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
result := tx.
Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
DoUpdates: groupUpsertColumns(),
},
).
Omit(clause.Associations).
@@ -692,8 +714,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
result := tx.
Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
DoUpdates: groupUpsertColumns(),
},
).
Omit(clause.Associations).
@@ -1995,7 +2018,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
}
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2005,7 +2028,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
var resources []byte
var refID sql.NullInt64
var refType sql.NullString
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
if err == nil {
if refID.Valid {
g.IntegrationReference.ID = int(refID.Int64)
@@ -2030,7 +2053,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
}
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2039,7 +2062,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
var p types.Policy
var checks []byte
var enabled sql.NullBool
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
if err == nil {
if enabled.Valid {
p.Enabled = enabled.Bool
@@ -2057,7 +2080,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
}
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2067,7 +2090,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
var network, domains, peerGroups, groups, accessGroups []byte
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
var metric sql.NullInt64
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
if err == nil {
if keepRoute.Valid {
r.KeepRoute = keepRoute.Bool
@@ -2109,7 +2132,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
}
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2118,7 +2141,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
var n nbdns.NameServerGroup
var ns, groups, domains []byte
var primary, enabled, searchDomainsEnabled sql.NullBool
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
if err == nil {
if primary.Valid {
n.Primary = primary.Bool
@@ -2154,7 +2177,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
}
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, name, description, checks FROM posture_checks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2162,7 +2185,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
var c posture.Checks
var checksDef []byte
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
err := row.Scan(&c.ID, &c.AccountID, &c.AccountSeqID, &c.Name, &c.Description, &checksDef)
if err == nil && checksDef != nil {
_ = json.Unmarshal(checksDef, &c.Checks)
}
@@ -2328,7 +2351,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
}
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
const query = `SELECT id, account_id, account_seq_id, name, description FROM networks WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2345,7 +2368,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
}
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2355,7 +2378,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
var peerGroups []byte
var masquerade, enabled sql.NullBool
var metric sql.NullInt64
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
if err == nil {
if masquerade.Valid {
r.Masquerade = masquerade.Bool
@@ -2383,7 +2406,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
}
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -2392,7 +2415,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
var r resourceTypes.NetworkResource
var prefix []byte
var enabled sql.NullBool
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
if err == nil {
if enabled.Valid {
r.Enabled = enabled.Bool
@@ -3565,6 +3588,262 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
}
}
// AllocateAccountSeqID returns the next per-account integer id for the given
// component kind. Must be called inside ExecuteInTransaction so the increment
// is serialized with the component insert.
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
}
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
switch engine {
case types.PostgresStoreEngine, types.SqliteStoreEngine:
return allocateAccountSeqIDReturning(db, accountID, entity)
case types.MysqlStoreEngine:
return allocateAccountSeqIDMysql(db, accountID, entity)
default:
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
}
}
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
// branch and returns next_id+1.
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, 2)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = account_seq_counters.next_id + 1
RETURNING (next_id - 1)
`
var allocated uint32
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
return 0, fmt.Errorf("upsert account seq counter: %w", err)
}
if allocated == 0 {
return 0, fmt.Errorf("upsert account seq counter returned 0")
}
return allocated, nil
}
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
// so the follow-up SELECT sees the same value.
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
const upsertSQL = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, LAST_INSERT_ID(2))
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
`
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
return 0, fmt.Errorf("upsert account seq counter: %w", err)
}
var newNext uint64
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
return 0, fmt.Errorf("get last insert id: %w", err)
}
if newNext == 0 {
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
}
return uint32(newNext - 1), nil
}
// assignAccountSeqIDs allocates a per-account integer id for any component on
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
// the canonical "save the whole account" path produces the same persisted seq
// ids that the manager-level Create paths produce. Update flows that go
// through SaveAccount preserve existing non-zero values; for those, the
// per-entity counter is bumped so subsequent AllocateAccountSeqID calls don't
// hand out a colliding id.
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
maxByEntity := make(map[types.AccountSeqEntity]uint32, 8)
bump := func(entity types.AccountSeqEntity, seq uint32) {
if seq > maxByEntity[entity] {
maxByEntity[entity] = seq
}
}
for i := range account.GroupsG {
g := account.GroupsG[i]
if g == nil {
continue
}
if g.AccountSeqID != 0 {
bump(types.AccountSeqEntityGroup, g.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
if err != nil {
return err
}
g.AccountSeqID = seq
// Defensive: generateAccountSQLTypes currently aliases the same
// *Group pointer into GroupsG and Groups[id] (so this is a no-op
// today), but mirror the seq anyway so any future divergence in
// how the two collections are populated doesn't silently leave
// the canonical map view stale.
if original, ok := account.Groups[g.ID]; ok && original != nil && original != g {
original.AccountSeqID = seq
}
}
for _, p := range account.Policies {
if p == nil {
continue
}
if p.AccountSeqID != 0 {
bump(types.AccountSeqEntityPolicy, p.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
if err != nil {
return err
}
p.AccountSeqID = seq
}
for i := range account.RoutesG {
r := &account.RoutesG[i]
if r.AccountSeqID != 0 {
bump(types.AccountSeqEntityRoute, r.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
if err != nil {
return err
}
r.AccountSeqID = seq
// Mirror the new seq onto the canonical map view so callers that
// hold the same in-memory account post-Save read a consistent
// AccountSeqID — without this, components/encoder code would see
// 0 for routes saved this transaction until the account is reloaded.
if original, ok := account.Routes[r.ID]; ok && original != nil {
original.AccountSeqID = seq
}
}
for i := range account.NameServerGroupsG {
ng := &account.NameServerGroupsG[i]
if ng.AccountSeqID != 0 {
bump(types.AccountSeqEntityNameserverGroup, ng.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
if err != nil {
return err
}
ng.AccountSeqID = seq
if original, ok := account.NameServerGroups[ng.ID]; ok && original != nil {
original.AccountSeqID = seq
}
}
for _, nr := range account.NetworkResources {
if nr == nil {
continue
}
if nr.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetworkResource, nr.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
if err != nil {
return err
}
nr.AccountSeqID = seq
}
for _, nr := range account.NetworkRouters {
if nr == nil {
continue
}
if nr.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetworkRouter, nr.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
if err != nil {
return err
}
nr.AccountSeqID = seq
}
for _, n := range account.Networks {
if n == nil {
continue
}
if n.AccountSeqID != 0 {
bump(types.AccountSeqEntityNetwork, n.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetwork)
if err != nil {
return err
}
n.AccountSeqID = seq
}
for _, pc := range account.PostureChecks {
if pc == nil {
continue
}
if pc.AccountSeqID != 0 {
bump(types.AccountSeqEntityPostureCheck, pc.AccountSeqID)
continue
}
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPostureCheck)
if err != nil {
return err
}
pc.AccountSeqID = seq
}
for entity, maxSeq := range maxByEntity {
if err := ensureAccountSeqCounter(tx, s.storeEngine, account.Id, entity, maxSeq+1); err != nil {
return fmt.Errorf("seed counter for %s: %w", entity, err)
}
}
return nil
}
// ensureAccountSeqCounter raises the per-account counter for entity to at
// least target. Used when SaveAccount persists components that already carry
// AccountSeqIDs (e.g. test bulk-load from sqlite to postgres, or migrations
// running before component data lands) so that the next AllocateAccountSeqID
// call returns a fresh id beyond what was just written.
func ensureAccountSeqCounter(db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity, target uint32) error {
switch engine {
case types.PostgresStoreEngine, types.SqliteStoreEngine:
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
`
// sqlite's UPSERT understands max() but the migration uses GREATEST
// for postgres and max() for sqlite. We collapse to dialect-specific
// statements only when needed.
if engine == types.SqliteStoreEngine {
const sqliteSQL = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON CONFLICT (account_id, entity) DO UPDATE
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
`
return db.Exec(sqliteSQL, accountID, string(entity), target).Error
}
return db.Exec(sqlStr, accountID, string(entity), target).Error
case types.MysqlStoreEngine:
const sqlStr = `
INSERT INTO account_seq_counters (account_id, entity, next_id)
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
`
return db.Exec(sqlStr, accountID, string(entity), target).Error
default:
return fmt.Errorf("unsupported store engine for account_seq counter: %v", engine)
}
}
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
@@ -3754,7 +4033,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
return status.Errorf(status.InvalidArgument, "group is nil")
}
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
@@ -3842,7 +4121,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store")

View File

@@ -220,6 +220,11 @@ type Store interface {
GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
// AllocateAccountSeqID returns the next per-account integer id for the given
// component kind. Must run inside a transaction so the increment is serialized
// with the component insert.
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
@@ -522,6 +527,30 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[networkTypes.Network](ctx, db, types.AccountSeqEntityNetwork, "id")
},
func(db *gorm.DB) error {
return migration.BackfillAccountSeqIDs[posture.Checks](ctx, db, types.AccountSeqEntityPostureCheck, "id")
},
}
}

View File

@@ -746,6 +746,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
}
// AllocateAccountSeqID mocks base method.
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
ret0, _ := ret[0].(uint32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
}
// ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper()

View File

@@ -1006,6 +1006,15 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
}
}
// PolicyRuleImpliesLegacySSH reports whether the rule (without an explicit
// NetbirdSSH protocol) implicitly authorises SSH because it permits TCP/22 or
// TCP/22022 — either by ALL-protocol coverage or by an explicit port/port-range
// containing one of those. Exposed for ToComponentSyncResponse so the
// envelope-format response mirrors the legacy SshConfig.SshEnabled bit.
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return policyRuleImpliesLegacySSH(rule)
}
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
}

View File

@@ -16,6 +16,49 @@ import (
"github.com/netbirdio/netbird/route"
)
// GetPeerNetworkMapResult dispatches to either the legacy-NetworkMap path or
// the components path based on the peer's capability and the kill switch.
// Capable peers (PeerCapabilityComponentNetworkMap) get the raw components
// shape — the server skips Calculate() entirely for them, saving CPU
// proportional to the number of capable peers in the account. Legacy peers
// (or any peer when componentsDisabled is true) get the fully-expanded
// NetworkMap as before.
func (a *Account) GetPeerNetworkMapResult(
ctx context.Context,
peerID string,
componentsDisabled bool,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) PeerNetworkMapResult {
peer := a.Peers[peerID]
if !componentsDisabled && peer != nil && peer.SupportsComponentNetworkMap() {
components := a.GetPeerNetworkMapComponents(
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs,
)
// Mirror legacy graceful-degrade: GetPeerNetworkMapFromComponents
// returns &NetworkMap{Network: a.Network.Copy()} when components is
// nil. Match that floor so the receiving client always sees the
// account Network identifier, not a fully-empty envelope.
if components == nil {
components = &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
}
}
return PeerNetworkMapResult{Components: components}
}
return PeerNetworkMapResult{
NetworkMap: a.GetPeerNetworkMapFromComponents(
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, metrics, groupIDToUserIDs,
),
}
}
func (a *Account) GetPeerNetworkMapFromComponents(
ctx context.Context,
peerID string,
@@ -82,15 +125,27 @@ func (a *Account) GetPeerNetworkMapComponents(
}
components := &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain,
ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer),
PeerID: peerID,
Network: a.Network.Copy(),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain,
ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer),
NetworkXIDToSeq: make(map[string]uint32, len(a.Networks)),
PostureCheckXIDToSeq: make(map[string]uint32, len(a.PostureChecks)),
}
for _, n := range a.Networks {
if n != nil && n.HasSeqID() {
components.NetworkXIDToSeq[n.ID] = n.AccountSeqID
}
}
for _, pc := range a.PostureChecks {
if pc != nil && pc.HasSeqID() {
components.PostureCheckXIDToSeq[pc.ID] = pc.AccountSeqID
}
}
components.AccountSettings = &AccountSettingsInfo{
@@ -253,18 +308,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
relevantPeerIDs[peerID] = a.GetPeer(peerID)
peerGroupSet := make(map[string]struct{}, 8)
for groupID, group := range a.Groups {
if slices.Contains(group.Peers, peerID) {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
peerGroupSet[groupID] = struct{}{}
}
}
routeAccessControlGroups := make(map[string]struct{})
for _, r := range a.Routes {
for _, groupID := range r.Groups {
if r == nil {
continue
}
relevant := r.Peer == peerID
if !relevant {
for _, groupID := range r.PeerGroups {
if _, ok := peerGroupSet[groupID]; ok {
relevant = true
break
}
}
}
if !relevant && r.Enabled {
for _, groupID := range r.Groups {
if _, ok := peerGroupSet[groupID]; ok {
relevant = true
break
}
}
}
if !relevant {
continue
}
for _, groupID := range r.PeerGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
for _, groupID := range r.PeerGroups {
for _, groupID := range r.Groups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
if r.Enabled {
@@ -485,6 +566,13 @@ func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChe
return dest
}
// filterGroupPeers trims each group's Peers slice to only those peers that
// also appear in `peers`. Groups whose filtered list is empty are NOT
// deleted from the map — they're kept so the components wire encoder can
// still resolve seq references from routes/policies/access-control groups
// that name them. Calculate() tolerates groups with empty Peers (the inner
// loops simply iterate zero times), so retaining them is behaviourally a
// no-op for the legacy path that consumes the same NetworkMapComponents.
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
for groupID, groupInfo := range *groups {
filteredPeers := make([]string, 0, len(groupInfo.Peers))
@@ -494,9 +582,7 @@ func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer)
}
}
if len(filteredPeers) == 0 {
delete(*groups, groupID)
} else if len(filteredPeers) != len(groupInfo.Peers) {
if len(filteredPeers) != len(groupInfo.Peers) {
ng := groupInfo.Copy()
ng.Peers = filteredPeers
(*groups)[groupID] = ng

View File

@@ -0,0 +1,29 @@
package types
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
type AccountSeqEntity string
const (
AccountSeqEntityPolicy AccountSeqEntity = "policy"
AccountSeqEntityGroup AccountSeqEntity = "group"
AccountSeqEntityRoute AccountSeqEntity = "route"
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
AccountSeqEntityNetwork AccountSeqEntity = "network"
AccountSeqEntityPostureCheck AccountSeqEntity = "posture_check"
)
// AccountSeqCounter tracks the next per-account integer id for a given component
// kind. Reads/writes go through the store inside the same transaction as the
// component insert so two concurrent inserts cannot collide on the same id.
type AccountSeqCounter struct {
AccountID string `gorm:"primaryKey;size:255"`
Entity string `gorm:"primaryKey;size:32"`
NextID uint32 `gorm:"not null;default:1"`
}
// TableName overrides the GORM-derived table name.
func (AccountSeqCounter) TableName() string {
return "account_seq_counters"
}

View File

@@ -19,6 +19,10 @@ type Group struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
// Name visible in the UI
Name string
@@ -41,6 +45,14 @@ type GroupPeer struct {
PeerID string `gorm:"primaryKey"`
}
// HasSeqID reports whether the group has been persisted long enough to have a
// per-account sequence id allocated. Wire encoders that key off AccountSeqID
// must skip groups that return false here — otherwise multiple unpersisted
// groups would collide on id 0.
func (g *Group) HasSeqID() bool {
return g != nil && g.AccountSeqID != 0
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
@@ -74,6 +86,7 @@ func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
AccountID: g.AccountID,
AccountSeqID: g.AccountSeqID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),

View File

@@ -42,6 +42,17 @@ type NetworkMapComponents struct {
PostureFailedPeers map[string]map[string]struct{}
RouterPeers map[string]*nbpeer.Peer
// NetworkXIDToSeq maps Network.ID (xid) → AccountSeqID. Populated by the
// account-side component builder; consumed by the envelope encoder to
// translate RoutersMap keys and NetworkResource.NetworkID references
// to compact uint32 ids. Legacy Calculate() doesn't consult it.
NetworkXIDToSeq map[string]uint32
// PostureCheckXIDToSeq maps posture.Checks.ID (xid) → AccountSeqID.
// Same role as NetworkXIDToSeq, used for PostureFailedPeers keys and
// policy SourcePostureChecks references.
PostureCheckXIDToSeq map[string]uint32
}
type AccountSettingsInfo struct {

View File

@@ -0,0 +1,181 @@
package types_test
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"testing"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/types"
)
// wireBenchScales mirrors the scales used by networkmap_benchmark_test.go but
// trimmed: encoding+marshal are linear, so we don't need the 30k peer extreme
// to see the trend.
var wireBenchScales = []benchmarkScale{
{"100peers_5groups", 100, 5},
{"500peers_20groups", 500, 20},
{"1000peers_50groups", 1000, 50},
{"5000peers_100groups", 5000, 100},
}
// populateAccountSeqIDs assigns deterministic AccountSeqIDs to every group and
// policy in the account so that the component encoder can reference them. The
// scalableTestAccount fixture builds entities by struct literal and skips this
// step, but production paths populate the IDs via the store layer.
func populateAccountSeqIDs(account *types.Account) {
var nextGroupSeq uint32 = 1
for _, g := range account.Groups {
g.AccountSeqID = nextGroupSeq
nextGroupSeq++
}
var nextPolicySeq uint32 = 1
for _, p := range account.Policies {
p.AccountSeqID = nextPolicySeq
nextPolicySeq++
}
}
// assignValidWgKeys overwrites every peer's Key with a valid base64-encoded
// 32-byte string. The default scalableTestAccount uses unparsable strings
// like "key-peer-0", which makes the components encoder emit a nil WgPubKey
// and the legacy encoder ship 10-char placeholders — both shrink the wire
// size in unrealistic ways. Production peers always have valid 44-char base64
// keys, so any benchmark/breakdown that wants honest numbers must call this.
func assignValidWgKeys(account *types.Account) {
for _, p := range account.Peers {
var raw [32]byte
_, _ = rand.Read(raw[:])
p.Key = base64.StdEncoding.EncodeToString(raw[:])
}
}
// BenchmarkNetworkMapWireEncode reports per-call ns and the marshaled wire
// size for both encoding paths. Run with:
//
// go test -run=^$ -bench=BenchmarkNetworkMapWireEncode -benchmem ./management/server/types/
func BenchmarkNetworkMapWireEncode(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range wireBenchScales {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
// Pre-encode once so the size metric is identical for every run inside
// the same scale; the b.Loop call only re-runs encode + Marshal.
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
if err != nil {
b.Fatalf("marshal legacy networkmap: %v", err)
}
envelopeInput := mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
}
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
envelopeBytes, err := goproto.Marshal(envelope)
if err != nil {
b.Fatalf("marshal envelope: %v", err)
}
b.Run(fmt.Sprintf("legacy/%s", scale.name), func(b *testing.B) {
b.ReportAllocs()
b.ReportMetric(float64(len(legacyBytes)), "bytes/msg")
b.ResetTimer()
for range b.N {
resp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
if _, err := goproto.Marshal(resp.NetworkMap); err != nil {
b.Fatal(err)
}
}
})
b.Run(fmt.Sprintf("components/%s", scale.name), func(b *testing.B) {
b.ReportAllocs()
b.ReportMetric(float64(len(envelopeBytes)), "bytes/msg")
b.ResetTimer()
for range b.N {
env := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
if _, err := goproto.Marshal(env); err != nil {
b.Fatal(err)
}
}
})
}
}
// BenchmarkNetworkMapWireSize is a fast snapshot of the wire size by scale
// without a tight encode loop. Run with -bench to see one ns/op + bytes per
// scale (treat the timing as informational; the sample is one Marshal per
// scale, not the full b.N loop).
func BenchmarkNetworkMapWireSize(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range wireBenchScales {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
if err != nil {
b.Fatalf("marshal legacy networkmap: %v", err)
}
env := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
})
envBytes, err := goproto.Marshal(env)
if err != nil {
b.Fatalf("marshal envelope: %v", err)
}
b.Run(fmt.Sprintf("size/%s", scale.name), func(b *testing.B) {
b.ReportMetric(float64(len(legacyBytes)), "legacy_bytes")
b.ReportMetric(float64(len(envBytes)), "components_bytes")
ratio := float64(len(envBytes)) / float64(len(legacyBytes))
b.ReportMetric(ratio, "components/legacy")
for range b.N {
}
})
}
}

View File

@@ -0,0 +1,150 @@
package types_test
import (
"context"
"fmt"
"os"
"testing"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkMapWireBreakdown is a one-shot diagnostic: it computes the wire
// size attributable to each top-level field of both the legacy NetworkMap and
// the components NetworkMapEnvelope at the 5000-peer scale, so the migration
// docs can attribute the size reduction to each optimization. Runs only on
// demand via -run TestNetworkMapWireBreakdown.
func TestNetworkMapWireBreakdown(t *testing.T) {
if testing.Short() {
t.Skip("size diagnostic, skipped with -short")
}
if os.Getenv("NB_RUN_WIRE_BREAKDOWN") != "1" {
t.Skip("set NB_RUN_WIRE_BREAKDOWN=1 to run wire breakdown diagnostic")
}
const peerCount, groupCount = 5000, 100
account, validatedPeers := scalableTestAccount(peerCount, groupCount)
populateAccountSeqIDs(account)
assignValidWgKeys(account)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
peerID := "peer-0"
peer := account.Peers[peerID]
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
dnsCache := &cache.DNSConfigCache{}
settings := &types.Settings{}
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
legacyTotal := mustMarshalSize(t, legacyResp.NetworkMap)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: components,
PeerConfig: legacyResp.NetworkMap.PeerConfig,
DNSDomain: "netbird.cloud",
})
componentsTotal := mustMarshalSize(t, envelope)
t.Logf("\n=== LEGACY NetworkMap (%d peers, %d groups) ===", peerCount, groupCount)
t.Logf(" Total: %d bytes\n", legacyTotal)
legacyBreakdown := []struct {
name string
nm *proto.NetworkMap
}{
{"RemotePeers", &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers}},
{"OfflinePeers", &proto.NetworkMap{OfflinePeers: legacyResp.NetworkMap.OfflinePeers}},
{"FirewallRules", &proto.NetworkMap{FirewallRules: legacyResp.NetworkMap.FirewallRules}},
{"Routes", &proto.NetworkMap{Routes: legacyResp.NetworkMap.Routes}},
{"RoutesFirewallRules", &proto.NetworkMap{RoutesFirewallRules: legacyResp.NetworkMap.RoutesFirewallRules}},
{"DNSConfig", &proto.NetworkMap{DNSConfig: legacyResp.NetworkMap.DNSConfig}},
{"PeerConfig", &proto.NetworkMap{PeerConfig: legacyResp.NetworkMap.PeerConfig}},
{"SshAuth", &proto.NetworkMap{SshAuth: legacyResp.NetworkMap.SshAuth}},
}
for _, e := range legacyBreakdown {
size := mustMarshalSize(t, e.nm)
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, legacyTotal))
}
full := envelope.GetFull()
if full == nil {
t.Fatalf("expected full network map envelope payload, got nil")
}
t.Logf("\n=== COMPONENTS NetworkMapEnvelope (%d peers, %d groups) ===", peerCount, groupCount)
t.Logf(" Total: %d bytes (%.1f%% of legacy)\n", componentsTotal, pct(componentsTotal, legacyTotal))
componentsBreakdown := []struct {
name string
nm *proto.NetworkMapComponentsFull
}{
{"Peers", &proto.NetworkMapComponentsFull{Peers: full.Peers}},
{"Policies", &proto.NetworkMapComponentsFull{Policies: full.Policies}},
{"Groups", &proto.NetworkMapComponentsFull{Groups: full.Groups}},
{"Routes (raw)", &proto.NetworkMapComponentsFull{Routes: full.Routes}},
{"NameServerGroups", &proto.NetworkMapComponentsFull{NameserverGroups: full.NameserverGroups}},
{"AllDNSRecords", &proto.NetworkMapComponentsFull{AllDnsRecords: full.AllDnsRecords}},
{"AccountZones", &proto.NetworkMapComponentsFull{AccountZones: full.AccountZones}},
{"NetworkResources", &proto.NetworkMapComponentsFull{NetworkResources: full.NetworkResources}},
{"RoutersMap", &proto.NetworkMapComponentsFull{RoutersMap: full.RoutersMap}},
{"ResourcePoliciesMap", &proto.NetworkMapComponentsFull{ResourcePoliciesMap: full.ResourcePoliciesMap}},
{"GroupIDToUserIDs", &proto.NetworkMapComponentsFull{GroupIdToUserIds: full.GroupIdToUserIds}},
{"AllowedUserIDs", &proto.NetworkMapComponentsFull{AllowedUserIds: full.AllowedUserIds}},
{"PostureFailedPeers", &proto.NetworkMapComponentsFull{PostureFailedPeers: full.PostureFailedPeers}},
{"DNSSettings", &proto.NetworkMapComponentsFull{DnsSettings: full.DnsSettings}},
{"PeerConfig", &proto.NetworkMapComponentsFull{PeerConfig: full.PeerConfig}},
{"AgentVersions", &proto.NetworkMapComponentsFull{AgentVersions: full.AgentVersions}},
}
for _, e := range componentsBreakdown {
size := mustMarshalSize(t, e.nm)
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, componentsTotal))
}
t.Logf("\n=== Per-PeerCompact average ===")
if len(full.Peers) > 0 {
t.Logf(" PeerCompact avg: %d bytes/peer", mustMarshalSize(t, &proto.NetworkMapComponentsFull{Peers: full.Peers})/len(full.Peers))
}
if len(legacyResp.NetworkMap.RemotePeers) > 0 {
t.Logf(" RemotePeer avg: %d bytes/peer",
mustMarshalSize(t, &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers})/len(legacyResp.NetworkMap.RemotePeers))
}
t.Logf("\n=== FirewallRule expansion footprint ===")
t.Logf(" legacy FirewallRules count: %d", len(legacyResp.NetworkMap.FirewallRules))
t.Logf(" components Policies count: %d", len(full.Policies))
t.Logf(" components Groups count: %d", len(full.Groups))
totalGroupPeerIdxs := 0
for _, g := range full.Groups {
totalGroupPeerIdxs += len(g.PeerIndexes)
}
t.Logf(" components peer-index refs across all groups: %d", totalGroupPeerIdxs)
}
func mustMarshalSize(t *testing.T, m goproto.Message) int {
b, err := goproto.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
return len(b)
}
func pct(part, total int) float64 {
if total == 0 {
return 0
}
return 100 * float64(part) / float64(total)
}
// Stops fmt being unused if the breakdown loop above is later commented out.
var _ = fmt.Sprintf

View File

@@ -0,0 +1,25 @@
package types
// PeerNetworkMapResult is what the network_map controller produces for a
// single peer. Exactly one of NetworkMap or Components is populated depending
// on the peer's capability:
//
// - Components-capable peers (PeerCapabilityComponentNetworkMap) get
// Components: the raw types.NetworkMapComponents the client decodes and
// runs Calculate() on locally. NetworkMap stays nil — the server skips
// the expansion entirely.
// - Legacy peers (or any peer when the kill switch is set) get NetworkMap:
// the fully-expanded view the legacy gRPC path consumes.
//
// The gRPC layer (ToSyncResponseForPeer) dispatches by which field is
// non-nil; callers must not rely on both being set.
type PeerNetworkMapResult struct {
NetworkMap *NetworkMap
Components *NetworkMapComponents
}
// IsComponents reports whether the result carries the components shape.
// Use this in preference to direct nil checks on the fields.
func (r PeerNetworkMapResult) IsComponents() bool {
return r.Components != nil
}

View File

@@ -0,0 +1,104 @@
package types_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// helper: marks the given peer as components-capable.
func markCapable(p *nbpeer.Peer) {
p.Meta.Capabilities = append(p.Meta.Capabilities, nbpeer.PeerCapabilityComponentNetworkMap)
}
func TestGetPeerNetworkMapResult_CapablePeerGetsComponents(t *testing.T) {
account, validatedPeers := scalableTestAccount(10, 2)
markCapable(account.Peers["peer-0"])
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
false, // componentsDisabled
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
require.True(t, result.IsComponents(), "capable peer must get the components shape")
assert.Nil(t, result.NetworkMap)
require.NotNil(t, result.Components)
assert.Equal(t, "peer-0", result.Components.PeerID)
}
func TestGetPeerNetworkMapResult_LegacyPeerGetsNetworkMap(t *testing.T) {
account, validatedPeers := scalableTestAccount(10, 2)
// peer-0 left without the component capability
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
false,
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
assert.False(t, result.IsComponents())
assert.Nil(t, result.Components)
require.NotNil(t, result.NetworkMap, "legacy peer must get a NetworkMap")
}
func TestGetPeerNetworkMapResult_KillSwitchOverridesCapability(t *testing.T) {
// Capable peer + componentsDisabled=true → falls back to legacy.
account, validatedPeers := scalableTestAccount(10, 2)
markCapable(account.Peers["peer-0"])
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
result := account.GetPeerNetworkMapResult(
context.Background(),
"peer-0",
true, // componentsDisabled = true (kill switch)
nbdns.CustomZone{},
nil,
validatedPeers,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
assert.False(t, result.IsComponents(), "kill switch must force legacy NetworkMap path")
assert.Nil(t, result.Components)
require.NotNil(t, result.NetworkMap)
}
func TestPeerNetworkMapResult_IsComponents(t *testing.T) {
assert.True(t, types.PeerNetworkMapResult{Components: &types.NetworkMapComponents{}}.IsComponents())
assert.False(t, types.PeerNetworkMapResult{NetworkMap: &types.NetworkMap{}}.IsComponents())
assert.False(t, types.PeerNetworkMapResult{}.IsComponents())
}

View File

@@ -59,6 +59,10 @@ type Policy struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
// Name of the Policy
Name string
@@ -75,11 +79,19 @@ type Policy struct {
SourcePostureChecks []string `gorm:"serializer:json"`
}
// HasSeqID reports whether the policy has been persisted long enough to have
// a per-account sequence id allocated. Wire encoders that key off
// AccountSeqID must skip policies that return false here.
func (p *Policy) HasSeqID() bool {
return p != nil && p.AccountSeqID != 0
}
// Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
AccountSeqID: p.AccountSeqID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,

View File

@@ -95,6 +95,9 @@ type Route struct {
ID ID `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"`
// Network and Domains are mutually exclusive
Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"`
@@ -128,6 +131,7 @@ func (r *Route) Copy() *Route {
route := &Route{
ID: r.ID,
AccountID: r.AccountID,
AccountSeqID: r.AccountSeqID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,

View File

@@ -316,27 +316,74 @@ func TestClient_Sync(t *testing.T) {
select {
case resp := <-ch:
if resp.GetPeerConfig() == nil {
if resp.GetPeerConfig() == nil && resp.GetNetworkMap().GetPeerConfig() == nil {
t.Error("expecting non nil PeerConfig got nil")
}
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
// Component-capable clients receive a NetworkMapEnvelope; the
// remote-peers list is encoded inside it. Decode it and check the
// envelope's peers slice. Legacy peers populate the top-level
// RemotePeers; both shapes must surface exactly one remote peer.
remotePeerKeys := remotePeerKeysFromSync(resp, testKey.PublicKey().String())
if len(remotePeerKeys) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(remotePeerKeys))
return
}
if resp.GetRemotePeersIsEmpty() == true {
if resp.GetNetworkMap() != nil && resp.GetRemotePeersIsEmpty() {
t.Error("expecting RemotePeers property to be false, got true")
}
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
if remotePeerKeys[0] != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), remotePeerKeys[0])
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")
}
}
// remotePeerKeysFromSync extracts the remote-peer WG keys from either the
// legacy NetworkMap.RemotePeers list or the components NetworkMapEnvelope's
// inner peers slice (filtering out the local receiving peer identified by
// localKey, since the envelope's peers list is index-addressed and includes
// the local peer alongside remotes).
func remotePeerKeysFromSync(resp *mgmtProto.SyncResponse, localKey string) []string {
if rp := resp.GetRemotePeers(); len(rp) > 0 {
out := make([]string, 0, len(rp))
for _, p := range rp {
out = append(out, p.GetWgPubKey())
}
return out
}
env := resp.GetNetworkMapEnvelope().GetFull()
if env == nil {
return nil
}
out := make([]string, 0, len(env.GetPeers()))
for _, p := range env.GetPeers() {
key := wgKeyFromBytes(p.GetWgPubKey())
if key == "" || key == localKey {
continue
}
out = append(out, key)
}
return out
}
// wgKeyFromBytes mirrors the client-side decoder: the envelope ships raw 32
// bytes; reconstruct the standard base64 key the test compares against.
func wgKeyFromBytes(raw []byte) string {
if len(raw) == 0 {
return ""
}
var k wgtypes.Key
if len(raw) != len(k) {
return ""
}
copy(k[:], raw)
return k.String()
}
func Test_SystemMetaDataFromClient(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()

View File

@@ -950,6 +950,13 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
func peerCapabilities(info system.Info) []proto.PeerCapability {
caps := []proto.PeerCapability{
proto.PeerCapability_PeerCapabilitySourcePrefixes,
// PeerCapabilityComponentNetworkMap signals that this client can
// decode the components-format SyncResponse.NetworkMapEnvelope and
// run Calculate() locally. Always advertised by Step-4-capable
// builds — there's no opt-out flag because the server-side kill
// switch (NB_NETWORK_MAP_COMPONENTS_DISABLE) covers emergency
// rollback and the client decoder is built in.
proto.PeerCapability_PeerCapabilityComponentNetworkMap,
}
if !info.DisableIPv6 {
caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay)

View File

@@ -0,0 +1,607 @@
package networkmap
import (
"encoding/base64"
"fmt"
"net"
"net/netip"
"strconv"
"time"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents
// the client can run Calculate() over. Every ID-reference on the wire is a
// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder
// synthesises consistent string IDs from the uint32s so the reconstructed
// components struct round-trips through Calculate exactly the way the
// server-side typed components would.
//
// Synthetic ID scheme (underscore-separated, visually distinct from the xid
// format Calculate would put in log lines under the legacy path):
//
// Peers "p_<wire_index>" // envelope.peers is index-addressed
// Groups "g_<account_seq_id>"
// Policies "pol_<account_seq_id>" // 1 rule per policy
// Routes "r_<account_seq_id>"
// Network resources "nres_<account_seq_id>"
// Posture checks "pc_<account_seq_id>"
// Networks "net_<account_seq_id>"
// Nameserver groups "nsg_<account_seq_id>"
func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) {
if env == nil {
return nil, fmt.Errorf("nil envelope")
}
full := env.GetFull()
if full == nil {
return nil, fmt.Errorf("envelope has no Full payload")
}
c := &types.NetworkMapComponents{
PeerID: "", // engine fills its own peer id from PeerConfig
Network: decodeAccountNetwork(full.Network),
AccountSettings: decodeAccountSettings(full.AccountSettings),
CustomZoneDomain: full.CustomZoneDomain,
Peers: make(map[string]*nbpeer.Peer, len(full.Peers)),
Groups: make(map[string]*types.Group, len(full.Groups)),
Policies: make([]*types.Policy, 0, len(full.Policies)),
Routes: make([]*nbroute.Route, 0, len(full.Routes)),
NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)),
AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords),
AccountZones: decodeCustomZones(full.AccountZones),
ResourcePoliciesMap: make(map[string][]*types.Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)),
RouterPeers: make(map[string]*nbpeer.Peer),
AllowedUserIDs: stringSliceToSet(full.AllowedUserIds),
PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)),
GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)),
}
if full.DnsSettings != nil {
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds),
}
} else {
c.DNSSettings = &types.DNSSettings{}
}
// Phase 1: peers. The envelope's peers slice is index-addressed; we
// build a peerOrder lookup for downstream references. Peer.ID is
// synthesized from the peer's wire index — wire format ships no xid
// for peers (and never has).
peerIDByIndex := make([]string, len(full.Peers))
for idx, pc := range full.Peers {
if pc == nil {
return nil, fmt.Errorf("invalid envelope: peers[%d] is nil", idx)
}
peerID := synthPeerID(uint32(idx))
peer := decodePeerCompact(pc, peerID, full.AgentVersions)
c.Peers[peerID] = peer
peerIDByIndex[idx] = peerID
}
// Phase 2: groups. AccountSeqID becomes both the synthesized string ID
// and the GroupCompact.id wire value.
for i, gc := range full.Groups {
if gc == nil {
return nil, fmt.Errorf("invalid envelope: groups[%d] is nil", i)
}
groupID := synthGroupID(gc.Id)
peerIDs := make([]string, 0, len(gc.PeerIndexes))
for _, idx := range gc.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerIDs = append(peerIDs, peerIDByIndex[idx])
}
}
c.Groups[groupID] = &types.Group{
ID: groupID,
AccountSeqID: gc.Id,
Name: gc.Name,
Peers: peerIDs,
}
}
// Phase 3: policies (PolicyCompact = one rule per entry; current data
// model is 1 rule per policy). Policy.ID is synthesized from the
// per-account seq id; proto.FirewallRule.PolicyID downstream carries
// the same synth string (no xid on the wire).
for i, pc := range full.Policies {
if pc == nil {
return nil, fmt.Errorf("invalid envelope: policies[%d] is nil", i)
}
policyID := synthPolicyID(pc.Id)
c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex))
}
// Phase 4: routes.
for i, rr := range full.Routes {
if rr == nil {
return nil, fmt.Errorf("invalid envelope: routes[%d] is nil", i)
}
c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex))
}
// Phase 5: NSGs.
for i, nsg := range full.NameserverGroups {
if nsg == nil {
return nil, fmt.Errorf("invalid envelope: nameserver_groups[%d] is nil", i)
}
c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg))
}
// Phase 6: network resources.
for i, nr := range full.NetworkResources {
if nr == nil {
return nil, fmt.Errorf("invalid envelope: network_resources[%d] is nil", i)
}
c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr))
}
// Phase 7: routers_map (outer key = network seq id, inner key = peer-id
// reconstructed from peer_index). Synthesized network id is "net_<seq>".
for networkSeq, list := range full.RoutersMap {
networkID := synthNetworkID(networkSeq)
inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries))
for _, entry := range list.Entries {
if !entry.PeerIndexSet {
continue
}
if int(entry.PeerIndex) >= len(peerIDByIndex) {
continue
}
peerID := peerIDByIndex[entry.PeerIndex]
inner[peerID] = &routerTypes.NetworkRouter{
ID: "",
NetworkID: networkID,
AccountSeqID: entry.Id,
Peer: peerID,
PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds),
Masquerade: entry.Masquerade,
Metric: int(entry.Metric),
Enabled: entry.Enabled,
}
}
if len(inner) > 0 {
c.RoutersMap[networkID] = inner
}
}
// Phase 8: resource_policies_map (resource seq id → list of *types.Policy
// pointers from the decoded policies slice). Resource ID is synthesized
// the same way as in decodeNetworkResource.
for resourceSeq, idxs := range full.ResourcePoliciesMap {
if len(idxs.Indexes) == 0 {
continue
}
resourceID := synthNetworkResourceID(resourceSeq)
policies := make([]*types.Policy, 0, len(idxs.Indexes))
for _, i := range idxs.Indexes {
if int(i) < len(c.Policies) {
policies = append(policies, c.Policies[i])
}
}
if len(policies) > 0 {
c.ResourcePoliciesMap[resourceID] = policies
}
}
// Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings.
for groupSeq, list := range full.GroupIdToUserIds {
c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...)
}
// Phase 10: posture_failed_peers — wire keys are posture-check seq ids,
// values are peer indexes that need to be turned into peer ids. PolicyRule
// SourcePostureChecks (also synth ids) reference the same key space.
for checkSeq, set := range full.PostureFailedPeers {
checkID := synthPostureCheckID(checkSeq)
failed := make(map[string]struct{}, len(set.PeerIndexes))
for _, idx := range set.PeerIndexes {
if int(idx) < len(peerIDByIndex) {
failed[peerIDByIndex[idx]] = struct{}{}
}
}
if len(failed) > 0 {
c.PostureFailedPeers[checkID] = failed
}
}
// Phase 11: router_peer_indexes — peers that act as routers. They're
// already in c.Peers (router peers are appended to the global peers
// list by the encoder); RouterPeers is the subset.
for _, idx := range full.RouterPeerIndexes {
if int(idx) < len(peerIDByIndex) {
peerID := peerIDByIndex[idx]
c.RouterPeers[peerID] = c.Peers[peerID]
}
}
return c, nil
}
func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network {
if an == nil {
return nil
}
n := &types.Network{
Identifier: an.Identifier,
Dns: an.Dns,
Serial: an.Serial,
}
if an.NetCidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil {
n.Net = *ipnet
}
}
if an.NetV6Cidr != "" {
if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil {
n.NetV6 = *ipnet
}
}
return n
}
func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo {
if as == nil {
return &types.AccountSettingsInfo{}
}
return &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs),
}
}
func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer {
var caps []int32
if pc.SupportsSourcePrefixes {
caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes)
}
if pc.SupportsIpv6 {
caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay)
}
peer := &nbpeer.Peer{
ID: peerID,
Key: encodeWgKeyBase64(pc.WgPubKey),
SSHKey: string(pc.SshPubKey),
SSHEnabled: pc.SshEnabled,
DNSLabel: pc.DnsLabel,
LoginExpirationEnabled: pc.LoginExpirationEnabled,
Meta: nbpeer.PeerSystemMeta{
WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx),
Capabilities: caps,
Flags: nbpeer.Flags{
ServerSSHAllowed: pc.ServerSshAllowed,
},
},
}
if pc.AddedWithSsoLogin {
// Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true.
// The original UserID isn't on the wire; the value is intentionally
// visibly synthetic so any future consumer that mistakes UserID for a
// real account user xid won't silently match (or worse, write the
// sentinel into a downstream record).
peer.UserID = "<env-sso>"
}
if pc.LastLoginUnixNano != 0 {
t := time.Unix(0, pc.LastLoginUnixNano)
peer.LastLogin = &t
}
switch len(pc.Ip) {
case 4:
peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]})
case 16:
var a [16]byte
copy(a[:], pc.Ip)
peer.IP = netip.AddrFrom16(a)
}
if len(pc.Ipv6) == 16 {
var a [16]byte
copy(a[:], pc.Ipv6)
peer.IPv6 = netip.AddrFrom16(a)
}
return peer
}
func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy {
rule := &types.PolicyRule{
ID: policyID, // 1 rule per policy → reuse synthesized id
PolicyID: policyID,
Enabled: true,
Action: actionFromProto(pc.Action),
Protocol: protocolFromProto(pc.Protocol),
Bidirectional: pc.Bidirectional,
Ports: uint32SliceToStrings(pc.Ports),
PortRanges: portRangesFromProto(pc.PortRanges),
Sources: groupIDsFromSeqs(pc.SourceGroupIds),
Destinations: groupIDsFromSeqs(pc.DestinationGroupIds),
AuthorizedUser: pc.AuthorizedUser,
AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups),
SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex),
DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex),
}
return &types.Policy{
ID: policyID,
AccountSeqID: pc.Id,
Enabled: true,
Rules: []*types.PolicyRule{rule},
SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds),
}
}
// resourceFromProto rebuilds types.Resource. For peer-typed resources the
// peer reference is reconstructed from the envelope's peer index — wire
// format ships no xid for peers, so we use the synthesized peer id.
func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource {
if r == nil {
return types.Resource{}
}
out := types.Resource{Type: types.ResourceType(r.Type)}
if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) {
out.ID = peerIDByIndex[r.PeerIndex]
}
return out
}
// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids.
// Mirrors groupIDsFromSeqs.
func postureCheckIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthPostureCheckID(s)
}
return out
}
// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form
// keys by group account_seq_id, the typed PolicyRule field keys by group
// xid string. We rebuild using the same synthetic scheme the rest of the
// decoder uses ("g<seq>").
func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string {
if len(m) == 0 {
return nil
}
out := make(map[string][]string, len(m))
for seq, list := range m {
if list == nil {
continue
}
out[synthGroupID(seq)] = append([]string(nil), list.Names...)
}
return out
}
func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route {
r := &nbroute.Route{
ID: nbroute.ID(synthRouteID(rr.Id)),
AccountSeqID: rr.Id,
NetID: nbroute.NetID(rr.NetId),
Description: rr.Description,
Domains: domainsFromPunycode(rr.Domains),
KeepRoute: rr.KeepRoute,
NetworkType: nbroute.NetworkType(rr.NetworkType),
Masquerade: rr.Masquerade,
Metric: int(rr.Metric),
Enabled: rr.Enabled,
Groups: groupIDsFromSeqs(rr.GroupIds),
AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds),
PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds),
SkipAutoApply: rr.SkipAutoApply,
}
if rr.NetworkCidr != "" {
if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil {
r.Network = p
}
}
if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) {
r.Peer = peerIDByIndex[rr.PeerIndex]
}
return r
}
func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup {
out := &nbdns.NameServerGroup{
ID: synthNameServerGroupID(nsg.Id),
AccountSeqID: nsg.Id,
Name: nsg.Name,
Description: nsg.Description,
Groups: groupIDsFromSeqs(nsg.GroupIds),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)),
}
for _, ns := range nsg.Nameservers {
if addr, err := netip.ParseAddr(ns.IP); err == nil {
out.NameServers = append(out.NameServers, nbdns.NameServer{
IP: addr,
NSType: nbdns.NameServerType(ns.NSType),
Port: int(ns.Port),
})
}
}
return out
}
func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource {
out := &resourceTypes.NetworkResource{
ID: synthNetworkResourceID(nr.Id),
AccountSeqID: nr.Id,
NetworkID: synthNetworkID(nr.NetworkSeq),
Name: nr.Name,
Description: nr.Description,
Type: resourceTypes.NetworkResourceType(nr.Type),
Address: nr.Address,
Domain: nr.DomainValue,
Enabled: nr.Enabled,
}
if nr.PrefixCidr != "" {
if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil {
out.Prefix = p
}
}
return out
}
func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord {
out := make([]nbdns.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, nbdns.SimpleRecord{
Name: r.Name,
Type: int(r.Type),
Class: r.Class,
TTL: int(r.TTL),
RData: r.RData,
})
}
return out
}
func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone {
out := make([]nbdns.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, nbdns.CustomZone{
Domain: z.Domain,
Records: decodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
// Synthetic ID generators — deterministic given the same wire input.
// Underscore-separated ("p_<n>", "pol_<n>", ...) so they're visually
// distinct in operator logs. fmt.Sprintf would dominate the decode hot path
// on large accounts (a 10k-peer envelope produces ~50k synth calls); the
// strconv.AppendUint builder keeps it allocation-light.
func synthID(prefix string, n uint32) string {
buf := make([]byte, 0, len(prefix)+10)
buf = append(buf, prefix...)
buf = strconv.AppendUint(buf, uint64(n), 10)
return string(buf)
}
func synthPeerID(idx uint32) string { return synthID("p_", idx) }
func synthGroupID(seq uint32) string { return synthID("g_", seq) }
func synthPolicyID(seq uint32) string { return synthID("pol_", seq) }
func synthRouteID(seq uint32) string { return synthID("r_", seq) }
func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) }
func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) }
func synthNetworkID(seq uint32) string { return synthID("net_", seq) }
func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) }
func groupIDsFromSeqs(seqs []uint32) []string {
if len(seqs) == 0 {
return nil
}
out := make([]string, len(seqs))
for i, s := range seqs {
out[i] = synthGroupID(s)
}
return out
}
func uint32SliceToStrings(ports []uint32) []string {
if len(ports) == 0 {
return nil
}
out := make([]string, len(ports))
for i, p := range ports {
out[i] = strconv.FormatUint(uint64(p), 10)
}
return out
}
func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange {
if len(ranges) == 0 {
return nil
}
out := make([]types.RulePortRange, 0, len(ranges))
for _, r := range ranges {
if r == nil || r.Start > 65535 || r.End > 65535 {
continue
}
out = append(out, types.RulePortRange{
Start: uint16(r.Start),
End: uint16(r.End),
})
}
return out
}
func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType {
if a == proto.RuleAction_DROP {
return types.PolicyTrafficActionDrop
}
return types.PolicyTrafficActionAccept
}
func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType {
switch p {
case proto.RuleProtocol_TCP:
return types.PolicyRuleProtocolTCP
case proto.RuleProtocol_UDP:
return types.PolicyRuleProtocolUDP
case proto.RuleProtocol_ICMP:
return types.PolicyRuleProtocolICMP
case proto.RuleProtocol_ALL:
return types.PolicyRuleProtocolALL
case proto.RuleProtocol_NETBIRD_SSH:
return types.PolicyRuleProtocolNetbirdSSH
default:
return types.PolicyRuleProtocolALL
}
}
func encodeWgKeyBase64(raw []byte) string {
if len(raw) != 32 {
return ""
}
return base64.StdEncoding.EncodeToString(raw)
}
func lookupAgentVersion(table []string, idx uint32) string {
if int(idx) < len(table) {
return table[idx]
}
return ""
}
func stringSliceToSet(s []string) map[string]struct{} {
if len(s) == 0 {
return nil
}
out := make(map[string]struct{}, len(s))
for _, v := range s {
out[v] = struct{}{}
}
return out
}
// domainsFromPunycode is a thin wrapper that converts a punycode list back to
// the domain.List type the route.Route struct expects. It accepts the
// punycode strings as-is (no extra decoding) — symmetric with
// route.Domains.ToPunycodeList() used in the encoder.
func domainsFromPunycode(punycoded []string) domain.List {
if len(punycoded) == 0 {
return nil
}
out := make(domain.List, 0, len(punycoded))
for _, d := range punycoded {
out = append(out, domain.Domain(d))
}
return out
}

View File

@@ -0,0 +1,323 @@
// Package networkmap contains the shared NetworkMap helpers that both the
// management server and the client agent need.
//
// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live
// here so the client can run the same conversion locally after deriving its
// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the
// server-side conversion package (which pulls in cloud integrations and is
// otherwise an unwanted internal import on the client).
//
// The helpers are pure functions over inputs — no caches, no IO, no logging
// beyond a context-aware error log when an individual user-id hash fails.
package networkmap
import (
"context"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
"net/netip"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
// ToProtocolRoutes converts a slice of typed routes to their proto form.
func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, ToProtocolRoute(r))
}
return protoRoutes
}
// ToProtocolRoute converts one typed route to its proto form.
func ToProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// ToProtocolFirewallRules converts the firewall rules to the protocol form.
// When useSourcePrefixes is true, the compact SourcePrefixes field is
// populated alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes
// when includeIPv6 is true.
func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: GetProtoDirection(rule.Direction),
Action: GetProtoAction(rule.Action),
Protocol: GetProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if ShouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is
// unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if ShouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// GetProtoDirection converts the direction to proto.RuleDirection.
func GetProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
// GetProtoAction converts the action to proto.RuleAction.
func GetProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// GetProtoProtocol converts the protocol to proto.RuleProtocol.
func GetProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
case types.PolicyRuleProtocolNetbirdSSH:
return proto.RuleProtocol_NETBIRD_SSH
default:
return proto.RuleProtocol_UNKNOWN
}
}
// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo.
func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
// ShouldUsePortRange reports whether the firewall rule should use a port
// range rather than a single port (TCP/UDP without a single port).
func ShouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall
// rules to proto.
func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: GetProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: GetProtoProtocol(rule.Protocol),
PortInfo: GetProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form.
func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form.
func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// DNSConfigCache is the cache contract for amortising NameServerGroup
// proto-conversion across peers in the same account. Server uses a concrete
// implementation; client passes nil (no cross-peer caching needed when
// rebuilding a single NetworkMap from an envelope).
type DNSConfigCache interface {
GetNameServerGroup(key string) (*proto.NameServerGroup, bool)
SetNameServerGroup(key string, value *proto.NameServerGroup)
}
// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is
// non-nil, NameServerGroup proto values are cached by NSG.ID across calls —
// the server amortises this across peers, the client passes nil.
func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone))
}
for _, nsGroup := range update.NameServerGroups {
if cache != nil {
if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
continue
}
}
protoGroup := ConvertToProtoNameServerGroup(nsGroup)
if cache != nil {
cache.SetNameServerGroup(nsGroup.ID, protoGroup)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig
// entries to dst and returns the result.
func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and
// builds per-machine-user index maps. Returns (hashedUsers, machineUsers).
// Errors from individual hash failures are logged via the provided context;
// they leave the offending user out of the result but don't abort the build.
func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).WithError(err).Error("failed to hash user id")
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}

View File

@@ -0,0 +1,210 @@
package networkmap
import (
"bytes"
"context"
"encoding/base64"
"fmt"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// EnvelopeResult is what the client engine consumes after receiving a
// component-format NetworkMap. Both fields are populated:
//
// - NetworkMap is the *proto.NetworkMap shape the engine reads today via
// update.GetNetworkMap() — built from the envelope's components by
// running Calculate() locally + converting back through the shared
// proto helpers + merging the optional ProxyPatch.
// - Components is the *types.NetworkMapComponents the engine retains so
// future incremental delta updates (Step 3) have a base to apply
// changes against. The client keeps it under its sync lock.
type EnvelopeResult struct {
NetworkMap *proto.NetworkMap
Components *types.NetworkMapComponents
}
// EnvelopeToNetworkMap is the full client-side pipeline: decode the
// component envelope back to a typed NetworkMapComponents, run Calculate()
// locally to produce the typed NetworkMap, convert it to the wire form the
// engine consumes, and fold in any ProxyPatch the server attached.
//
// localPeerKey is the receiving peer's WG pub key (used to derive
// includeIPv6 / useSourcePrefixes from the receiving peer's own record in
// the components struct, mirroring legacy ToSyncResponse behaviour).
//
// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when
// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries.
func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) {
components, err := DecodeEnvelope(env)
if err != nil {
return nil, fmt.Errorf("decode envelope: %w", err)
}
// Find the receiving peer in the decoded components by WG key so we can
// derive its capabilities and set components.PeerID for Calculate(). The
// envelope.peers list is index-addressed; we synthesized IDs as "p<idx>".
localPeerID, localPeer := findPeerByWgKey(components, localPeerKey)
if localPeer == nil {
return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers))
}
components.PeerID = localPeerID
includeIPv6 := localPeer.SupportsIPv6() && localPeer.IPv6.IsValid()
useSourcePrefixes := localPeer.SupportsSourcePrefixes()
typedNM := components.Calculate(ctx)
full := env.GetFull()
dnsFwdPort := int64(0)
if full != nil {
dnsFwdPort = full.DnsForwarderPort
}
protoNM := &proto.NetworkMap{
Serial: typedNM.Network.CurrentSerial(),
}
if full != nil {
protoNM.PeerConfig = full.PeerConfig
}
protoNM.Routes = ToProtocolRoutes(typedNM.Routes)
protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort)
remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6)
protoNM.RemotePeers = remotePeers
protoNM.RemotePeersIsEmpty = len(remotePeers) == 0
protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6)
firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes)
protoNM.FirewallRules = firewallRules
protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules)
protoNM.RoutesFirewallRules = routesFirewallRules
protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if typedNM.AuthorizedUsers != nil {
hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers)
userIDClaim := ""
if full != nil {
userIDClaim = full.UserIdClaim
}
protoNM.SshAuth = &proto.SSHAuth{
AuthorizedUsers: hashedUsers,
MachineUsers: machineUsers,
UserIDClaim: userIDClaim,
}
}
if typedNM.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules))
for _, rule := range typedNM.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
protoNM.ForwardingRules = forwardingRules
}
// Merge the proxy patch the server attached. Mirrors the legacy
// NetworkMap.Merge step that the server runs after Calculate().
if full != nil && full.ProxyPatch != nil {
mergeProxyPatch(protoNM, full.ProxyPatch)
}
return &EnvelopeResult{
NetworkMap: protoNM,
Components: components,
}, nil
}
// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the
// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge
// — same six collections, deduplicated where the legacy merge dedupes.
func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) {
nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers)
nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers)
nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...)
nm.Routes = append(nm.Routes, patch.Routes...)
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...)
nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...)
if len(nm.RemotePeers) > 0 {
nm.RemotePeersIsEmpty = false
}
if len(nm.FirewallRules) > 0 {
nm.FirewallRulesIsEmpty = false
}
if len(nm.RoutesFirewallRules) > 0 {
nm.RoutesFirewallRulesIsEmpty = false
}
}
// appendUniquePeers dedupes by WgPubKey — mirrors legacy
// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the
// closest stable identifier is WgPubKey).
func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig {
if len(extra) == 0 {
return dst
}
seen := make(map[string]struct{}, len(dst))
for _, p := range dst {
if p == nil {
continue
}
seen[p.WgPubKey] = struct{}{}
}
for _, p := range extra {
if p == nil {
continue
}
if _, ok := seen[p.WgPubKey]; ok {
continue
}
seen[p.WgPubKey] = struct{}{}
dst = append(dst, p)
}
return dst
}
func trimKey(s string) string {
if len(s) > 12 {
return s[:12]
}
return s
}
// findPeerByWgKey locates the receiving peer in the decoded components by
// matching its WireGuard public key. Compares raw 32-byte decode output —
// not the base64 string — because production data has occasional non-canonical
// padding bits that round-trip through the envelope's `bytes wg_pub_key`
// field, canonicalising the encoding (semantically equivalent key, different
// string). Decodes `wgKey` once up front and reuses a stack buffer in the
// loop so an N-peer search is ~zero-alloc.
func findPeerByWgKey(c *types.NetworkMapComponents, wgKey string) (string, *nbpeer.Peer) {
const wgKeyRawLen = 32
var (
targetRaw [wgKeyRawLen]byte
haveRaw bool
)
if n, err := base64.StdEncoding.Decode(targetRaw[:], []byte(wgKey)); err == nil && n == wgKeyRawLen {
haveRaw = true
}
var peerRaw [wgKeyRawLen]byte
for id, p := range c.Peers {
if p == nil {
continue
}
if p.Key == wgKey {
return id, p
}
if !haveRaw {
continue
}
n, err := base64.StdEncoding.Decode(peerRaw[:], []byte(p.Key))
if err == nil && n == wgKeyRawLen && bytes.Equal(peerRaw[:], targetRaw[:]) {
return id, p
}
}
return "", nil
}

View File

@@ -0,0 +1,173 @@
package networkmap_test
import (
"context"
"crypto/rand"
"encoding/base64"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline:
// build a small components struct, encode an envelope, marshal/unmarshal the
// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is
// non-empty and consistent. Deeper per-field semantic equivalence with the
// legacy server path is covered by the prod-DB equivalence test in
// management/server/store/networkmap_envelope_equivalence_test.go.
func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err, "marshal envelope")
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err, "EnvelopeToNetworkMap")
require.NotNil(t, result)
require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil")
require.NotNil(t, result.Components, "Components must be retained for future delta updates")
require.NotNil(t, result.Components.AccountSettings)
require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer")
require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules")
}
// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the
// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into
// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP
// before forming firewall rules (see networkmap_components.go:282 and
// account.go:868). Without that rewrite, agents fall into UNKNOWN-protocol
// handling, which on some platforms downgrades to allow-all — a real
// security regression.
func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) {
c, localPeerKey := buildSmokeComponents(t)
// Replace the smoke policy with a NetbirdSSH-protocol allow.
c.Policies = []*types.Policy{{
ID: "pol-ssh", AccountSeqID: 2, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-ssh",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}}
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
wire, err := goproto.Marshal(envelope)
require.NoError(t, err)
var decoded proto.NetworkMapEnvelope
require.NoError(t, goproto.Unmarshal(wire, &decoded))
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
require.NoError(t, err)
require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules")
for i, fr := range result.NetworkMap.FirewallRules {
require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol,
"FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i)
}
}
func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) {
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud")
require.Error(t, err, "nil envelope must produce an error rather than panic")
}
func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) {
env := &proto.NetworkMapEnvelope{}
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud")
require.Error(t, err, "envelope with no Full payload must produce an error")
}
// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1
// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient
// to validate the encode → marshal → decode → Calculate pipeline produces
// non-empty output.
func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) {
t.Helper()
peerAKey := randomWgKey(t)
peerBKey := randomWgKey(t)
peerA := &nbpeer.Peer{
ID: "peer-A",
Key: peerAKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peerA",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-B",
Key: peerBKey,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
DNSLabel: "peerB",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
group := &types.Group{
ID: "group-all", AccountSeqID: 1, Name: "All",
Peers: []string{"peer-A", "peer-B"},
}
policy := &types.Policy{
ID: "pol-allow", AccountSeqID: 1, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-allow",
Enabled: true,
Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-all"},
Destinations: []string{"group-all"},
}},
}
c := &types.NetworkMapComponents{
PeerID: "peer-A",
Network: &types.Network{
Identifier: "net-smoke",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 1,
},
AccountSettings: &types.AccountSettingsInfo{},
DNSSettings: &types.DNSSettings{},
Peers: map[string]*nbpeer.Peer{
"peer-A": peerA,
"peer-B": peerB,
},
Groups: map[string]*types.Group{
"group-all": group,
},
Policies: []*types.Policy{policy},
}
return c, peerAKey
}
func randomWgKey(t *testing.T) string {
t.Helper()
var raw [32]byte
_, err := rand.Read(raw[:])
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(raw[:])
}

File diff suppressed because it is too large Load Diff

View File

@@ -133,6 +133,12 @@ message SyncResponse {
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
// NetworkMapEnvelope carries the component-based wire format for peers that
// advertise PeerCapabilityComponentNetworkMap. When set, NetworkMap (field 5)
// is left empty: management ships components and the client runs Calculate()
// locally instead of receiving an expanded NetworkMap.
NetworkMapEnvelope NetworkMapEnvelope = 7;
}
message SyncMetaRequest {
@@ -212,6 +218,8 @@ enum PeerCapability {
PeerCapabilitySourcePrefixes = 1;
// Client handles IPv6 overlay addresses and firewall rules.
PeerCapabilityIPv6Overlay = 2;
// Client receives NetworkMap as components and assembles it locally.
PeerCapabilityComponentNetworkMap = 3;
}
// PeerSystemMeta is machine meta data like OS and version.
@@ -569,6 +577,13 @@ enum RuleProtocol {
UDP = 3;
ICMP = 4;
CUSTOM = 5;
// NETBIRD_SSH (types.PolicyRuleProtocolType "netbird-ssh") is the marker
// policy rule that drives SSH-server activation in Calculate(). The legacy
// proto.FirewallRule path doesn't ship this value (Calculate already
// expands SSH rules into TCP/22 before encoding), but the components path
// ships RAW policies — the client must see this protocol to derive
// AuthorizedUsers locally.
NETBIRD_SSH = 6;
}
enum RuleDirection {
@@ -709,3 +724,462 @@ message StopExposeRequest {
}
message StopExposeResponse {}
// =====================================================================
// Component-based NetworkMap wire format (PeerCapabilityComponentNetworkMap).
//
// Peers that advertise this capability receive NetworkMap building blocks
// (peers + groups + policies + routes + dns + ssh + forwarding) and run the
// expansion (Calculate) locally instead of receiving a fully-expanded
// NetworkMap from the server.
// =====================================================================
// NetworkMapEnvelope wraps either a full snapshot or a delta. Step 2 ships
// only Full; Delta is reserved for the incremental-update work.
message NetworkMapEnvelope {
oneof payload {
NetworkMapComponentsFull full = 1;
NetworkMapComponentsDelta delta = 2;
}
}
// NetworkMapComponentsFull is the full per-peer component snapshot. The
// client decodes it into a types.NetworkMapComponents and runs Calculate()
// locally to produce the same NetworkMap the legacy server path would have
// produced. Every field carries RAW component data — no server-side
// expansion (firewall rules, DNS config, SSH auth, route firewall rules,
// forwarding rules) is shipped; the client computes those itself.
message NetworkMapComponentsFull {
uint64 serial = 1;
// Peer config for the receiving peer (legacy proto.PeerConfig kept as-is —
// it carries the receiving peer's own overlay address, FQDN, SSH config).
PeerConfig peer_config = 2;
// Account-level network metadata (id, IPv4/IPv6 overlay subnets, DNS,
// serial). Mirrors types.Network.
AccountNetwork network = 3;
// Account-level settings the client needs for its local Calculate().
AccountSettingsCompact account_settings = 4;
// Account DNS settings (mirrors types.DNSSettings).
DNSSettingsCompact dns_settings = 5;
// Domain shared across all peers in this account, e.g. "netbird.cloud".
// Each peer's FQDN is dns_label + "." + dns_domain.
string dns_domain = 6;
// Custom-zone domain for this peer's view (c.CustomZoneDomain). Empty when
// the peer has no custom zone records.
string custom_zone_domain = 7;
// Deduplicated agent versions; PeerCompact.agent_version_idx indexes here.
// Empty string at index 0 if any peer has no version.
repeated string agent_versions = 8;
// All peers (deduplicated). The client splits peers into online / offline
// locally using account_settings.peer_login_expiration on receive.
repeated PeerCompact peers = 9;
// Indexes into peers for the subset that may act as routers.
repeated uint32 router_peer_indexes = 10;
// Policies that affect the receiving peer.
repeated PolicyCompact policies = 11;
// Groups in unspecified order — clients key off id (account_seq_id).
repeated GroupCompact groups = 12;
// Routes relevant to this peer, raw shape (mirrors []*route.Route).
repeated RouteRaw routes = 13;
// Nameserver groups (mirrors []*nbdns.NameServerGroup).
repeated NameServerGroupRaw nameserver_groups = 14;
// All DNS records the client needs to assemble its custom zone. Reuses
// the existing SimpleRecord wire shape.
repeated SimpleRecord all_dns_records = 15;
// Custom zones (typically the peer's own zone). Reuses the existing
// CustomZone wire shape.
repeated CustomZone account_zones = 16;
// Network resources (mirrors []*resourceTypes.NetworkResource).
repeated NetworkResourceRaw network_resources = 17;
// Routers per network. Outer key: network account_seq_id. Each entry is
// the set of routers backing that network for this peer's view.
//
// INCOMPATIBLE WIRE CHANGE: the map key changed from string (network xid)
// to uint32 (account_seq_id). Field 18 was reused without a `reserved`
// entry because capability=3 has never been released — every cap=3
// producer and consumer carries the same regenerated descriptor. Do NOT
// reuse this pattern for any further wire change once cap=3 ships.
map<uint32, NetworkRouterList> routers_map = 18;
// For each NetworkResource account_seq_id, the indexes into policies[]
// that apply to it.
//
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
map<uint32, PolicyIndexes> resource_policies_map = 19;
// Group-id (account_seq_id) → user ids authorized for SSH on members.
map<uint32, UserIDList> group_id_to_user_ids = 20;
// Account-level allowed user ids (used by Calculate() when assembling SSH
// authorized users for the receiving peer).
repeated string allowed_user_ids = 21;
// Per posture-check account_seq_id, the set of peer indexes that failed
// the check. Server-side evaluation result; clients do not re-evaluate.
//
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
map<uint32, PeerIndexSet> posture_failed_peers = 22;
// Account-level DNS forwarder port (mirrors the legacy
// proto.DNSConfig.ForwarderPort). Computed by the controller from peer
// versions; clients fold it into their Calculate() DNS output.
int64 dns_forwarder_port = 23;
// Pre-expanded NetworkMap fragments injected post-Calculate by external
// controllers (BYOP / port-forwarding proxies). The receiving client
// merges these into its locally-computed NetworkMap the same way the
// legacy server does via NetworkMap.Merge — so downstream consumers see
// a unified merged result regardless of source.
ProxyPatch proxy_patch = 24;
// SSH UserIDClaim — server-side HttpServerConfig.AuthUserIDClaim, or
// "sub" by default. Populated in proto.SSHAuth.UserIDClaim when the
// client rebuilds the NetworkMap from this envelope. Empty when the
// account has no AuthorizedUsers (and thus no SshAuth to populate).
string user_id_claim = 25;
// Reserved for future component additions (incremental_serial, parent_seq,
// etc.) without forcing a renumber.
reserved 26 to 50;
}
// ProxyPatch carries NetworkMap fragments that don't fit the component-graph
// model — they're pre-expanded by external controllers (BYOP /
// port-forwarding proxies) and injected post-Calculate. Fields use the
// legacy wire types because the proxy delivers them pre-formed; there is
// no raw component shape to convert from. Empty when no proxy is active.
message ProxyPatch {
repeated RemotePeerConfig peers = 1;
repeated RemotePeerConfig offline_peers = 2;
repeated FirewallRule firewall_rules = 3;
repeated Route routes = 4;
repeated RouteFirewallRule route_firewall_rules = 5;
repeated ForwardingRule forwarding_rules = 6;
}
// AccountSettingsCompact carries the account-level settings the client needs
// to evaluate locally. Mirrors the subset of types.AccountSettingsInfo that
// Calculate() actually reads — login-expiration (used to filter expired
// peers). Inactivity expiration is purely server-side bookkeeping and is not
// shipped.
message AccountSettingsCompact {
bool peer_login_expiration_enabled = 1;
// Login expiration window. Unit is nanoseconds (matches time.Duration).
int64 peer_login_expiration_ns = 2;
}
// AccountNetwork is the account-level overlay metadata. Mirrors types.Network
// so the client can populate NetworkMap.Network without a server round-trip.
message AccountNetwork {
string identifier = 1;
// IPv4 overlay subnet in CIDR form (e.g. "100.64.0.0/16").
string net_cidr = 2;
// IPv6 ULA overlay subnet in CIDR form (e.g. "fd00:4e42::/64"). Empty when
// the account has no IPv6 overlay yet.
string net_v6_cidr = 3;
string dns = 4;
uint64 serial = 5;
}
// NetworkMapComponentsDelta is reserved for the incremental update protocol
// (Step 3 of the migration plan). Field numbers 1100 are pre-allocated to
// keep room for the planned event types without needing a renumber.
message NetworkMapComponentsDelta {
reserved 1 to 100;
}
// PeerCompact is the wire-shape of a remote peer used by the component
// format. It carries every field of types.Peer that the client's local
// Calculate() reads — including the trio needed to evaluate
// LoginExpired() (added_with_sso_login + login_expiration_enabled +
// last_login_unix_nano). Fields the client does not consume (Status,
// CreatedAt, etc.) are not shipped.
message PeerCompact {
// Raw 32-byte WireGuard public key (no base64 wrapping).
bytes wg_pub_key = 1;
// Raw 4-byte IPv4 overlay address. Always a /32 host route, so no prefix
// byte is needed.
bytes ip = 2;
// Raw 16-byte IPv6 overlay address; always a /128 host route. Empty when
// the peer has no IPv6 overlay address.
bytes ipv6 = 3;
// Raw SSH public key bytes (or empty).
bytes ssh_pub_key = 4;
// DNS label without the account's domain suffix. Full FQDN is
// dns_label + "." + NetworkMapComponentsFull.dns_domain.
string dns_label = 5;
// Index into NetworkMapComponentsFull.agent_versions.
uint32 agent_version_idx = 6;
// True iff the peer was added via SSO login (i.e., types.Peer.UserID is
// non-empty). Combined with login_expiration_enabled and
// last_login_unix_nano this lets the client reproduce
// (*Peer).LoginExpired() locally.
bool added_with_sso_login = 7;
// True when the peer's login can expire — mirrors
// types.Peer.LoginExpirationEnabled.
bool login_expiration_enabled = 8;
// Unix-nanosecond timestamp of the peer's last login. 0 when the peer has
// never logged in (server stores nil; client treats 0 as "epoch", which
// makes a fresh peer immediately expired iff login_expiration_enabled is
// true — the same semantics as types.Peer.GetLastLogin).
int64 last_login_unix_nano = 9;
// True when the peer has an SSH server enabled locally. Used by the
// legacy SSH path in Calculate() (`policyRuleImpliesLegacySSH`): a rule
// with protocol ALL/TCP-with-SSH-ports activates SSH for the receiving
// peer when this bit is set, even without an explicit NetbirdSSH rule.
bool ssh_enabled = 10;
reserved 11; // was: id (string xid)
// Mirror of types.Peer.SupportsIPv6() — !Meta.Flags.DisableIPv6 &&
// HasCapability(PeerCapabilityIPv6Overlay). Used by the local peer's
// Calculate() when deciding whether to emit IPv6 firewall rules
// (appendIPv6FirewallRule) against this peer's IPv6 address.
bool supports_ipv6 = 12;
// Mirror of types.Peer.SupportsSourcePrefixes() —
// HasCapability(PeerCapabilitySourcePrefixes). Determines whether the
// local peer's Calculate() emits SourcePrefixes alongside legacy PeerIP
// fields in proto.FirewallRule.
bool supports_source_prefixes = 13;
// Mirror of types.Peer.Meta.Flags.ServerSSHAllowed. Read by Calculate()
// when expanding TCP port-22 firewall rules — the native SSH companion
// (port 22022) is only added when this flag is set and the peer agent
// version supports it.
bool server_ssh_allowed = 14;
}
// PolicyCompact is the compact form of a policy rule. Group references use
// the per-account integer ids from account_seq_counters; the client resolves
// them against NetworkMapComponentsFull.groups. Direction is derived per-peer
// on the client (ingress when the peer is in destination_group_ids, egress
// when in source_group_ids; both when bidirectional).
message PolicyCompact {
// Per-account integer id (matches policies.account_seq_id). Used as a
// stable reference for ResourcePoliciesMap.indexes and future delta
// updates (Step 3).
uint32 id = 1;
RuleAction action = 2;
RuleProtocol protocol = 3;
bool bidirectional = 4;
// Single ports referenced by the rule.
repeated uint32 ports = 5;
// Port ranges (start..end) referenced by the rule.
repeated PortInfo.Range port_ranges = 6;
// Group ids (account_seq_id) of source / destination groups.
repeated uint32 source_group_ids = 7;
repeated uint32 destination_group_ids = 8;
reserved 9; // was: xid (string)
// SSH authorization fields. PolicyRule.AuthorizedGroups maps the rule's
// applicable group ids (account_seq_id) to a list of local-user names —
// when a peer in one of those groups is the SSH destination, the named
// local users gain access. AuthorizedUser is the single-user form
// (legacy: rule scopes SSH to one specific user id).
//
// Both fields are only consumed by Calculate() when the rule's protocol
// is NetbirdSSH (or the legacy implicit-SSH heuristic).
map<uint32, UserNameList> authorized_groups = 10;
string authorized_user = 11;
// Resource-typed rule sources/destinations. When a rule targets a specific
// peer (rather than groups), Calculate() reads SourceResource /
// DestinationResource — without these the rule's connection resources
// can't be produced on the client. ResourceCompact's peer_index refers to
// NetworkMapComponentsFull.peers; type is the raw ResourceType string
// ("peer", "host", "subnet", "domain"). Only "peer" is meaningful for
// Calculate's resource-typed rule path today.
ResourceCompact source_resource = 12;
ResourceCompact destination_resource = 13;
// Posture-check seq ids gating this policy's source peers. Calculate()
// reads them when filtering rule peers (peers that fail any listed check
// are dropped from sourcePeers). Match keys in
// NetworkMapComponentsFull.posture_failed_peers.
repeated uint32 source_posture_check_seq_ids = 15;
reserved 14; // was: source_posture_check_ids (repeated string xid)
}
// ResourceCompact mirrors types.Resource. Used by PolicyCompact to carry
// rule.SourceResource / rule.DestinationResource when the rule targets a
// specific resource (typically a peer) rather than groups.
// peer_index_set tells whether peer_index is valid (proto3 uint32 cannot
// disambiguate "0" from "unset"); set only when type == "peer".
message ResourceCompact {
string type = 1;
bool peer_index_set = 2;
uint32 peer_index = 3;
reserved 4; // future: host/subnet/domain references when needed
}
// UserNameList is a list of local-user names — used as the value type in
// PolicyCompact.authorized_groups.
message UserNameList {
repeated string names = 1;
}
// GroupCompact is the wire-shape of a group: per-account integer id, optional
// name, and indexes into NetworkMapComponentsFull.peers identifying members.
message GroupCompact {
// Per-account integer id (matches groups.account_seq_id). Used by
// PolicyCompact.source_group_ids / destination_group_ids.
uint32 id = 1;
// Group name; only sent when non-empty (clients use it for diagnostics).
string name = 2;
// Indexes into NetworkMapComponentsFull.peers.
repeated uint32 peer_indexes = 3;
}
// DNSSettingsCompact mirrors types.DNSSettings.
message DNSSettingsCompact {
// Group ids (account_seq_id) whose DNS management is disabled.
repeated uint32 disabled_management_group_ids = 1;
}
// RouteRaw mirrors *route.Route (the domain type), trimmed to fields that
// types.NetworkMapComponents.Calculate() reads. Group references are
// account_seq_ids; the routing peer (when set) is referenced by index into
// NetworkMapComponentsFull.peers.
message RouteRaw {
// Per-account integer id (matches routes.account_seq_id).
uint32 id = 1;
string net_id = 2;
string description = 3;
// Either network_cidr (e.g. "10.0.0.0/16") or domains is set, not both.
string network_cidr = 4;
repeated string domains = 5;
bool keep_route = 6;
// Routing peer reference: peer_index_set tells whether peer_index is valid
// (proto3 uint32 cannot disambiguate "0" from "unset"). Mutually exclusive
// with peer_group_ids.
//
// peer_index decodes back to types.Peer.ID (the peer's xid string), NOT
// to its WireGuard public key. This matches the server-side data flow:
// c.Routes carry route.Peer = peer.ID, and getRoutingPeerRoutes mutates
// it to peer.Key only after the route has been admitted to the network
// map. Decoders MUST set Route.Peer = peer.ID; the legacy Calculate()
// path will substitute the WG key downstream.
bool peer_index_set = 7;
uint32 peer_index = 8;
repeated uint32 peer_group_ids = 9;
int32 network_type = 10;
bool masquerade = 11;
int32 metric = 12;
bool enabled = 13;
repeated uint32 group_ids = 14;
repeated uint32 access_control_group_ids = 15;
bool skip_auto_apply = 16;
reserved 17; // was: xid (string)
}
// NameServerGroupRaw mirrors *nbdns.NameServerGroup. Distinct from the
// legacy NameServerGroup (which is the wire-trimmed shape consumed by
// proto.DNSConfig and lacks the Name/Description/Groups/Enabled fields).
message NameServerGroupRaw {
uint32 id = 1; // nameserver_groups.account_seq_id
string name = 2;
string description = 3;
// Reuses the legacy NameServer wire shape (IP as string).
repeated NameServer nameservers = 4;
// Group ids (account_seq_id) the NSG distributes nameservers to.
repeated uint32 group_ids = 5;
bool primary = 6;
repeated string domains = 7;
bool enabled = 8;
bool search_domains_enabled = 9;
}
// NetworkResourceRaw mirrors *resourceTypes.NetworkResource.
//
// INCOMPATIBLE WIRE CHANGE: field 2 changed from `string network_id` (xid)
// to `uint32 network_seq` without a `reserved` entry. Safe only because
// capability=3 has never been released — every cap=3 producer and consumer
// carries the same regenerated descriptor. Do NOT reuse this pattern once
// cap=3 ships.
message NetworkResourceRaw {
uint32 id = 1; // network_resources.account_seq_id
uint32 network_seq = 2; // networks.account_seq_id (replaces xid)
string name = 3;
string description = 4;
// Resource type: "host" / "subnet" / "domain".
string type = 5;
string address = 6;
string domain_value = 7; // resource.Domain
string prefix_cidr = 8;
bool enabled = 9;
reserved 10; // was: xid (string)
}
// NetworkRouterList carries the routers backing one network.
message NetworkRouterList {
// Routers in this network, keyed by peer_index (the routing peer).
repeated NetworkRouterEntry entries = 1;
}
// NetworkRouterEntry mirrors a single *routerTypes.NetworkRouter; the routing
// peer is referenced by index into NetworkMapComponentsFull.peers.
message NetworkRouterEntry {
uint32 id = 1; // network_routers.account_seq_id
uint32 peer_index = 2;
bool peer_index_set = 3;
repeated uint32 peer_group_ids = 4;
bool masquerade = 5;
int32 metric = 6;
bool enabled = 7;
}
// PolicyIndexes is a list of indexes into NetworkMapComponentsFull.policies.
message PolicyIndexes {
repeated uint32 indexes = 1;
}
// UserIDList is a list of user ids — used as the value type in
// NetworkMapComponentsFull.group_id_to_user_ids.
message UserIDList {
repeated string user_ids = 1;
}
// PeerIndexSet is a set of peer indexes — used as the value type in
// NetworkMapComponentsFull.posture_failed_peers.
message PeerIndexSet {
repeated uint32 peer_indexes = 1;
}