mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 16:19:56 +00:00
Compare commits
16 Commits
v0.71.2
...
nmap/compo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13e41e432c | ||
|
|
efa6a3f502 | ||
|
|
5fbcdeceac | ||
|
|
3a1bbeba90 | ||
|
|
728057ef15 | ||
|
|
582cd70086 | ||
|
|
9bbbafaf69 | ||
|
|
672b057aa0 | ||
|
|
b9a0186200 | ||
|
|
9083bdb977 | ||
|
|
b194af48b8 | ||
|
|
4543780ef0 | ||
|
|
2de0283971 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 |
@@ -61,9 +61,11 @@ import (
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
@@ -202,6 +204,13 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||
// NetworkMap that Calculate() produced from it so Step 3 incremental
|
||||
// updates have a base to apply changes against. nil for legacy-format
|
||||
// peers. Guarded by syncMsgMux.
|
||||
latestComponents *types.NetworkMapComponents
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -865,8 +874,12 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
@@ -907,11 +920,45 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
var (
|
||||
nm *mgmProto.NetworkMap
|
||||
components *types.NetworkMapComponents
|
||||
)
|
||||
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||
// Components-format peer: decode the envelope back to typed
|
||||
// components, run Calculate() locally, and convert to the wire
|
||||
// NetworkMap shape the rest of the engine consumes. Components are
|
||||
// retained so future incremental updates (Step 3) can apply deltas
|
||||
// instead of doing a full reconstruction.
|
||||
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
dnsName := ""
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||
// shared domain by stripping the peer's own label prefix. Falls
|
||||
// back to empty if the FQDN doesn't have the expected shape.
|
||||
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||
}
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode network map envelope: %w", err)
|
||||
}
|
||||
nm = result.NetworkMap
|
||||
components = result.Components
|
||||
} else {
|
||||
nm = update.GetNetworkMap()
|
||||
}
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only retain the components view when the server sent the envelope
|
||||
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||
// here would clobber a previously-cached snapshot, breaking the Step 3
|
||||
// incremental-delta base on a future envelope sync.
|
||||
if components != nil {
|
||||
e.latestComponents = components
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
e.syncRespMux.RLock()
|
||||
@@ -937,6 +984,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||
// receiving peer's FQDN — the same value the management server fills as
|
||||
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||
for i := 0; i < len(fqdn); i++ {
|
||||
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||
return fqdn[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
|
||||
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||
// Name group name
|
||||
Name string
|
||||
// Description group description
|
||||
|
||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
if file == "" {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
return newSQLite3(file).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/server/signer"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -77,7 +76,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Initialize SQLite storage
|
||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
||||
sqliteConfig := newSQLite3(dbPath)
|
||||
stor, err := sqliteConfig.Open(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||
|
||||
@@ -55,6 +55,15 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
// componentsDisabled is the kill switch for the component-based wire
|
||||
// format. When true the controller emits legacy proto.NetworkMap to every
|
||||
// peer regardless of capability — used to roll back instantly via a
|
||||
// management restart from a bad components encoder.
|
||||
//
|
||||
// Set once in NewController from NB_NETWORK_MAP_COMPONENTS_DISABLE and
|
||||
// never written after — readers race-free without a mutex.
|
||||
componentsDisabled bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -81,12 +90,30 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
}
|
||||
}
|
||||
|
||||
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||
// component-based wire format for this peer. Combines the peer's advertised
|
||||
// capability with the controller-level kill switch — callers ask exactly
|
||||
// this question, so encapsulating it removes accidental double-checks.
|
||||
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||
}
|
||||
|
||||
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||
// literal — matches the convention used elsewhere in the codebase
|
||||
// (e.g. event.go's NB_TRAFFIC_EVENT_*) and reduces operator surprises.
|
||||
func parseBoolEnv(key string) bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -192,18 +219,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||
// the client merges it into Calculate()'s output the same
|
||||
// way the legacy server did via NetworkMap.Merge.
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
@@ -314,11 +349,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
@@ -329,7 +364,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
@@ -376,6 +416,67 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||
// encodes both into the wire envelope. The caller is responsible for
|
||||
// checking peer capability + componentsDisabled before dispatching here —
|
||||
// this method does NOT branch on capability itself.
|
||||
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
// Fetch the proxy network map fragment for this peer alongside the
|
||||
// components — same single-account-load path the streaming controller
|
||||
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||
// instead of waiting for the next streaming push.
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
|
||||
@@ -22,6 +22,10 @@ type Controller interface {
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
@@ -130,6 +130,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||
ret2, _ := ret[2].(*types.NetworkMap)
|
||||
ret3, _ := ret[3].([]*posture.Checks)
|
||||
ret4, _ := ret[4].(int64)
|
||||
ret5, _ := ret[5].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// PeerNeedsComponents mocks base method.
|
||||
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
@@ -47,6 +48,11 @@ type EphemeralManager struct {
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
|
||||
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||
// no-op when the receiver is nil so deployments without an app
|
||||
// metrics provider work unchanged.
|
||||
metrics *telemetry.EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
@@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||
// reflected in the gauge. Pass nil to detach.
|
||||
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||
e.peersLock.Lock()
|
||||
e.metrics = m
|
||||
e.peersLock.Unlock()
|
||||
}
|
||||
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
@@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.removePeer(peer.ID)
|
||||
if e.removePeer(peer.ID) {
|
||||
e.metrics.DecPending(1)
|
||||
}
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
@@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
e.metrics.IncPending()
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
@@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
e.metrics.AddPending(int64(len(peers)))
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
@@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
// Drop the gauge by the number of entries we just took off the list,
|
||||
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||
// list invariant is what the gauge tracks; failed delete batches are
|
||||
// counted separately via CountCleanupError so we can still see them.
|
||||
if len(deletePeers) > 0 {
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.DecPending(int64(len(deletePeers)))
|
||||
}
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
@@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
e.metrics.CountCleanupError()
|
||||
continue
|
||||
}
|
||||
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) removePeer(id string) {
|
||||
// removePeer drops the entry from the linked list. Returns true if a
|
||||
// matching entry was found and removed so callers can keep the pending
|
||||
// metric gauge in sync.
|
||||
func (e *EphemeralManager) removePeer(id string) bool {
|
||||
if e.headPeer == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
@@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
@@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
|
||||
@@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
if metrics := s.Metrics(); metrics != nil {
|
||||
em.SetMetrics(metrics.EphemeralPeersMetrics())
|
||||
}
|
||||
return em
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
815
management/internals/shared/grpc/components_encoder.go
Normal file
815
management/internals/shared/grpc/components_encoder.go
Normal file
@@ -0,0 +1,815 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||
const wgKeyRawLen = 32
|
||||
|
||||
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||
// In Step 2 the envelope is fully self-contained — every field needed by the
|
||||
// client's local Calculate() comes from the components struct itself. The
|
||||
// only externally-supplied data is the receiving peer's PeerConfig (which is
|
||||
// computed alongside the components in the network_map controller and reused
|
||||
// from the legacy proto path) and the dns_domain string.
|
||||
type ComponentsEnvelopeInput struct {
|
||||
Components *types.NetworkMapComponents
|
||||
PeerConfig *proto.PeerConfig
|
||||
DNSDomain string
|
||||
DNSForwarderPort int64
|
||||
// UserIDClaim is the OIDC claim name the client should embed in
|
||||
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||
// is OK — client treats empty as "no SshAuth to build".
|
||||
UserIDClaim string
|
||||
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||
// is present; encoder skips the field in that case.
|
||||
ProxyPatch *proto.ProxyPatch
|
||||
}
|
||||
|
||||
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||
// Go maps in their native (random) order. Indexes inside the envelope
|
||||
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||
// are self-consistent within a single encode, so the decoder reconstructs
|
||||
// the same typed objects regardless of emit order. Tests that need to
|
||||
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||
// not byte-equal.
|
||||
//
|
||||
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||
// index spaces are local to a single envelope. Delta sync (Step 3+) will
|
||||
// use a different shape for the same reason.
|
||||
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||
c := in.Components
|
||||
|
||||
// Graceful degrade when components is nil — matches the legacy path's
|
||||
// account_components.go:43 behaviour for missing/unvalidated peers
|
||||
// (return a NetworkMap with only Network populated). The receiver gets
|
||||
// an envelope it can decode without crashing; AccountSettings stays
|
||||
// non-nil so client-side dereferences are safe.
|
||||
if c == nil {
|
||||
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||
// populated (account_components.go:43). The receiver gets enough to
|
||||
// bootstrap (Network identifier, dns_domain, account_settings) and
|
||||
// nothing else.
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{
|
||||
Full: &proto.NetworkMapComponentsFull{
|
||||
PeerConfig: in.PeerConfig,
|
||||
DnsDomain: in.DNSDomain,
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
AccountSettings: &proto.AccountSettingsCompact{},
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||
// peers that exist only in c.RouterPeers would silently lose their
|
||||
// peer_index reference.
|
||||
enc := newComponentEncoder(c)
|
||||
enc.indexAllPeers()
|
||||
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||
|
||||
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||
// translate every *Policy pointer to a wire index.
|
||||
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||
|
||||
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||
// every encoder either reads from the dedup tables or works on
|
||||
// independent input.
|
||||
full := &proto.NetworkMapComponentsFull{
|
||||
Serial: networkSerial(c.Network),
|
||||
PeerConfig: in.PeerConfig,
|
||||
Network: toAccountNetwork(c.Network),
|
||||
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||
DnsDomain: in.DNSDomain,
|
||||
CustomZoneDomain: c.CustomZoneDomain,
|
||||
AgentVersions: enc.agentVersions,
|
||||
Peers: enc.peers,
|
||||
RouterPeerIndexes: routerIdxs,
|
||||
Policies: policies,
|
||||
Groups: enc.encodeGroups(),
|
||||
Routes: enc.encodeRoutes(c.Routes),
|
||||
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||
AccountZones: encodeCustomZones(c.AccountZones),
|
||||
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||
}
|
||||
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||
}
|
||||
}
|
||||
|
||||
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||
// production path always populates c.Network (account_components.go:86), but
|
||||
// the encoder is exported and a hand-built components struct may omit it.
|
||||
func networkSerial(n *types.Network) uint64 {
|
||||
if n == nil {
|
||||
return 0
|
||||
}
|
||||
return n.CurrentSerial()
|
||||
}
|
||||
|
||||
type componentEncoder struct {
|
||||
components *types.NetworkMapComponents
|
||||
|
||||
peerOrder map[string]uint32
|
||||
peers []*proto.PeerCompact
|
||||
|
||||
agentVersionOrder map[string]uint32
|
||||
agentVersions []string
|
||||
}
|
||||
|
||||
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||
return &componentEncoder{
|
||||
components: c,
|
||||
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||
agentVersionOrder: make(map[string]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) indexAllPeers() {
|
||||
for _, p := range e.components.Peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
e.appendPeer(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := uint32(len(e.peers))
|
||||
e.peerOrder[p.ID] = idx
|
||||
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||
return idx
|
||||
}
|
||||
|
||||
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||
return idx
|
||||
}
|
||||
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||
// a WtVersion don't force the table to materialise.
|
||||
if v == "" {
|
||||
idx := uint32(len(e.agentVersions))
|
||||
if idx == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
}
|
||||
e.agentVersionOrder[""] = idx
|
||||
return idx
|
||||
}
|
||||
if len(e.agentVersions) == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
e.agentVersionOrder[""] = 0
|
||||
}
|
||||
idx := uint32(len(e.agentVersions))
|
||||
e.agentVersionOrder[v] = idx
|
||||
e.agentVersions = append(e.agentVersions, v)
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||
if len(routers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(routers))
|
||||
for _, p := range routers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, e.appendPeer(p))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||
if len(e.components.Groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||
for _, g := range e.components.Groups {
|
||||
if !g.HasSeqID() {
|
||||
continue
|
||||
}
|
||||
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||
for _, peerID := range g.Peers {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
peerIdxs = append(peerIdxs, idx)
|
||||
}
|
||||
}
|
||||
out = append(out, &proto.GroupCompact{
|
||||
Id: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
PeerIndexes: peerIdxs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||
// that list — used by encodeResourcePoliciesMap to translate
|
||||
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||
|
||||
for _, pol := range policies {
|
||||
if !pol.HasSeqID() || !pol.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, r := range pol.Rules {
|
||||
if r == nil || !r.Enabled {
|
||||
continue
|
||||
}
|
||||
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||
out = append(out, e.encodePolicyRule(pol, r))
|
||||
}
|
||||
}
|
||||
return out, idxByPolicy
|
||||
}
|
||||
|
||||
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||
return &proto.PolicyCompact{
|
||||
Id: pol.AccountSeqID,
|
||||
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||
Bidirectional: r.Bidirectional,
|
||||
Ports: portsToUint32(r.Ports),
|
||||
PortRanges: portRangesToProto(r.PortRanges),
|
||||
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||
AuthorizedUser: r.AuthorizedUser,
|
||||
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||
SourceResource: e.resourceToProto(r.SourceResource),
|
||||
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||
}
|
||||
}
|
||||
|
||||
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||
// dropping any group that has no seq id assigned.
|
||||
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(src))
|
||||
for _, gid := range src {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// unionPolicies merges c.Policies with every policy referenced by
|
||||
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||
// from the wire and the client's resource-policy lookup would come back
|
||||
// empty.
|
||||
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||
// Fast path: non-router peers have no resource-only policies, so the
|
||||
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||
if len(resourcePolicies) == 0 {
|
||||
return policies
|
||||
}
|
||||
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
for _, list := range resourcePolicies {
|
||||
for _, p := range list {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||
// group xid → local-user names) to the wire form (map keyed by group
|
||||
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||
// matches how source/destination group references handle the same case.
|
||||
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||
for groupID, names := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||
g, ok := e.components.Groups[groupID]
|
||||
if !ok || !g.HasSeqID() {
|
||||
return 0, false
|
||||
}
|
||||
return g.AccountSeqID, true
|
||||
}
|
||||
|
||||
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||
// resources the peer id is converted to a peer index into the envelope's
|
||||
// peers array. For other resource types only the type string is shipped
|
||||
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||
// for "peer" — other types fall through to group-based lookup).
|
||||
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||
if r.ID == "" && r.Type == "" {
|
||||
return nil
|
||||
}
|
||||
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||
out.PeerIndexSet = true
|
||||
out.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||
// references handle the same case.
|
||||
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(xids))
|
||||
for _, xid := range xids {
|
||||
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// networkSeq translates a Network xid to its per-account integer id using
|
||||
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||
// the xid isn't known — callers decide whether to skip the parent record.
|
||||
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||
if xid == "" {
|
||||
return 0, false
|
||||
}
|
||||
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||
if !ok || seq == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return seq, true
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := &proto.DNSSettingsCompact{
|
||||
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||
}
|
||||
for _, gid := range s.DisabledManagementGroups {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||
if len(routes) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
rr := &proto.RouteRaw{
|
||||
Id: r.AccountSeqID,
|
||||
NetId: string(r.NetID),
|
||||
Description: r.Description,
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: int32(r.NetworkType),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
SkipAutoApply: r.SkipAutoApply,
|
||||
Domains: r.Domains.ToPunycodeList(),
|
||||
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
}
|
||||
if r.Network.IsValid() {
|
||||
rr.NetworkCidr = r.Network.String()
|
||||
}
|
||||
if r.Peer != "" {
|
||||
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||
rr.PeerIndexSet = true
|
||||
rr.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||
if len(nsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||
for _, nsg := range nsgs {
|
||||
if nsg == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NameServerGroupRaw{
|
||||
Id: nsg.AccountSeqID,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Nameservers: encodeNameServers(nsg.NameServers),
|
||||
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServer, 0, len(servers))
|
||||
for _, s := range servers {
|
||||
out = append(out, &proto.NameServer{
|
||||
IP: s.IP.String(),
|
||||
NSType: int64(s.NSType),
|
||||
Port: int64(s.Port),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, &proto.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int64(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int64(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||
if len(zones) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, &proto.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: encodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||
if len(resources) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||
for _, r := range resources {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkResourceRaw{
|
||||
Id: r.AccountSeqID,
|
||||
Name: r.Name,
|
||||
Description: r.Description,
|
||||
Type: string(r.Type),
|
||||
Address: r.Address,
|
||||
DomainValue: r.Domain,
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||
entry.NetworkSeq = seq
|
||||
}
|
||||
if r.Prefix.IsValid() {
|
||||
entry.PrefixCidr = r.Prefix.String()
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||
if len(routersMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||
for networkXID, routers := range routersMap {
|
||||
if len(routers) == 0 {
|
||||
continue
|
||||
}
|
||||
netSeq, ok := e.networkSeq(networkXID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||
for peerID, r := range routers {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkRouterEntry{
|
||||
Id: r.AccountSeqID,
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
entry.PeerIndexSet = true
|
||||
entry.PeerIndex = idx
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||
if len(rpm) == 0 {
|
||||
return nil
|
||||
}
|
||||
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||
// (small slice). Network resources without seq id are dropped, matching how
|
||||
// other components-without-seq are silently filtered.
|
||||
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||
for _, r := range e.components.NetworkResources {
|
||||
if r != nil && r.AccountSeqID != 0 {
|
||||
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||
}
|
||||
}
|
||||
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||
for resourceXID, policies := range rpm {
|
||||
seq, ok := resourceXIDToSeq[resourceXID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(policies)*2)
|
||||
for _, pol := range policies {
|
||||
idxs = append(idxs, policyToIdxs[pol]...)
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||
for groupID, userIDs := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok || len(userIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSetToSlice(s map[string]struct{}) []string {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s))
|
||||
for k := range s {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||
for checkXID, failedPeerIDs := range m {
|
||||
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||
if !ok || seq == 0 {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||
for peerID := range failedPeerIDs {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||
// (which shouldn't happen in production but the encoder is exported)
|
||||
// degrades to login_expiration_enabled = false, which makes
|
||||
// LoginExpired() return false for every peer.
|
||||
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||
if s == nil {
|
||||
return &proto.AccountSettingsCompact{}
|
||||
}
|
||||
return &proto.AccountSettingsCompact{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||
}
|
||||
}
|
||||
|
||||
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
out := &proto.AccountNetwork{
|
||||
Identifier: n.Identifier,
|
||||
NetCidr: n.Net.String(),
|
||||
Dns: n.Dns,
|
||||
Serial: n.CurrentSerial(),
|
||||
}
|
||||
if len(n.NetV6.IP) > 0 {
|
||||
out.NetV6Cidr = n.NetV6.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||
pc := &proto.PeerCompact{
|
||||
WgPubKey: decodeWgKey(p.Key),
|
||||
SshPubKey: []byte(p.SSHKey),
|
||||
DnsLabel: p.DNSLabel,
|
||||
AgentVersionIdx: agentVersionIdx,
|
||||
AddedWithSsoLogin: p.UserID != "",
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
SshEnabled: p.SSHEnabled,
|
||||
SupportsIpv6: p.SupportsIPv6(),
|
||||
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||
}
|
||||
if p.LastLogin != nil {
|
||||
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||
}
|
||||
switch {
|
||||
case !p.IP.IsValid():
|
||||
// leave Ip nil
|
||||
case p.IP.Is4() || p.IP.Is4In6():
|
||||
ip := p.IP.Unmap().As4()
|
||||
pc.Ip = ip[:]
|
||||
default:
|
||||
ip := p.IP.As16()
|
||||
pc.Ip = ip[:]
|
||||
}
|
||||
if p.IPv6.IsValid() {
|
||||
ip := p.IPv6.As16()
|
||||
pc.Ipv6 = ip[:]
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||
// key, or nil for an empty / malformed key.
|
||||
func decodeWgKey(s string) []byte {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, wgKeyRawLen)
|
||||
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||
if err != nil || n != wgKeyRawLen {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portsToUint32(ports []string) []uint32 {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(ports))
|
||||
for _, p := range ports {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, uint32(v))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
out = append(out, &proto.PortInfo_Range{
|
||||
Start: uint32(r.Start),
|
||||
End: uint32(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
|
||||
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||
// input compare byte-equal via proto.Equal.
|
||||
//
|
||||
// This lives on the test side — the encoder itself emits in map-iteration
|
||||
// order. Test-side normalization is the contract for "two encodes are
|
||||
// equivalent".
|
||||
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||
if full == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||
// index 0 by convention.
|
||||
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||
if len(full.AgentVersions) > 0 {
|
||||
// Pair version → original index, sort, rebuild.
|
||||
type avEntry struct {
|
||||
version string
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]avEntry, len(full.AgentVersions))
|
||||
for i, v := range full.AgentVersions {
|
||||
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||
}
|
||||
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||
// keeps the canonicalize output stable when two entries compare
|
||||
// equal (the encoder dedups, but defending against future inputs).
|
||||
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||
if a.version == "" && b.version != "" {
|
||||
return -1
|
||||
}
|
||||
if b.version == "" && a.version != "" {
|
||||
return 1
|
||||
}
|
||||
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||
})
|
||||
newVersions := make([]string, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
avRemap[e.oldIdx] = uint32(newIdx)
|
||||
newVersions[newIdx] = e.version
|
||||
}
|
||||
full.AgentVersions = newVersions
|
||||
}
|
||||
for _, p := range full.Peers {
|
||||
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||
p.AgentVersionIdx = newIdx
|
||||
}
|
||||
}
|
||||
|
||||
type peerEntry struct {
|
||||
peer *proto.PeerCompact
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]peerEntry, len(full.Peers))
|
||||
for i, p := range full.Peers {
|
||||
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||
}
|
||||
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||
// nil from malformed keys, or both empty for placeholders).
|
||||
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||
})
|
||||
|
||||
remap := make(map[uint32]uint32, len(entries))
|
||||
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
remap[e.oldIdx] = uint32(newIdx)
|
||||
newPeers[newIdx] = e.peer
|
||||
}
|
||||
full.Peers = newPeers
|
||||
|
||||
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||
for _, g := range full.Groups {
|
||||
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||
}
|
||||
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, r := range full.Routes {
|
||||
if r.PeerIndexSet {
|
||||
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||
r.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(r.GroupIds)
|
||||
slices.Sort(r.AccessControlGroupIds)
|
||||
slices.Sort(r.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, list := range full.RoutersMap {
|
||||
for _, entry := range list.Entries {
|
||||
if entry.PeerIndexSet {
|
||||
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||
entry.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(entry.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||
}
|
||||
|
||||
for _, set := range full.PostureFailedPeers {
|
||||
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||
}
|
||||
|
||||
for _, p := range full.Policies {
|
||||
slices.Sort(p.SourceGroupIds)
|
||||
slices.Sort(p.DestinationGroupIds)
|
||||
}
|
||||
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||
// a Policy has multiple rules) still get a deterministic order. After
|
||||
// sorting we remap indexes in ResourcePoliciesMap.
|
||||
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||
for i, p := range full.Policies {
|
||||
policyOldOrder[p] = uint32(i)
|
||||
}
|
||||
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||
return c
|
||||
}
|
||||
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||
})
|
||||
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||
for newIdx, p := range full.Policies {
|
||||
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||
}
|
||||
for _, idxs := range full.ResourcePoliciesMap {
|
||||
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||
}
|
||||
for _, list := range full.GroupIdToUserIds {
|
||||
slices.Sort(list.UserIds)
|
||||
}
|
||||
slices.Sort(full.AllowedUserIds)
|
||||
}
|
||||
|
||||
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||
out := make([]uint32, 0, len(idxs))
|
||||
for _, i := range idxs {
|
||||
if newIdx, ok := remap[i]; ok {
|
||||
out = append(out, newIdx)
|
||||
}
|
||||
}
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||
// the encoder is intentionally non-deterministic.
|
||||
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||
canonicalize(a.GetFull())
|
||||
canonicalize(b.GetFull())
|
||||
return goproto.Equal(a, b)
|
||||
}
|
||||
|
||||
func newTestComponents() *types.NetworkMapComponents {
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-a",
|
||||
Key: testWgKeyA,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peera",
|
||||
SSHKey: "ssh-a",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-b",
|
||||
Key: testWgKeyB,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||
DNSLabel: "peerb",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||
}
|
||||
peerC := &nbpeer.Peer{
|
||||
ID: "peer-c",
|
||||
Key: testWgKeyC,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
return &types.NetworkMapComponents{
|
||||
PeerID: "peer-a",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 7,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 2 * time.Hour,
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-a": peerA,
|
||||
"peer-b": peerB,
|
||||
"peer-c": peerC,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "pol-1",
|
||||
AccountSeqID: 10,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||
Ports: []string{"22", "80"},
|
||||
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full, "envelope must contain Full payload")
|
||||
|
||||
assert.EqualValues(t, 7, full.Serial)
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
|
||||
require.NotNil(t, full.Network)
|
||||
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||
|
||||
require.NotNil(t, full.AccountSettings)
|
||||
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
byLabel := map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||
// run produces different wire bytes, but the canonicalized form must
|
||||
// match.
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d must be semantically equivalent to first encode", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, got := range results {
|
||||
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"goroutine %d produced inequivalent envelope", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 2)
|
||||
|
||||
groupByID := map[uint32]*proto.GroupCompact{}
|
||||
for _, g := range full.Groups {
|
||||
groupByID[g.Id] = g
|
||||
}
|
||||
require.Contains(t, groupByID, uint32(1))
|
||||
require.Contains(t, groupByID, uint32(2))
|
||||
assert.Equal(t, "Src", groupByID[1].Name)
|
||||
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.EqualValues(t, 10, pc.Id)
|
||||
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||
assert.True(t, pc.Bidirectional)
|
||||
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||
require.Len(t, pc.PortRanges, 1)
|
||||
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.RouterPeerIndexes, 1)
|
||||
idx := full.RouterPeerIndexes[0]
|
||||
require.Less(t, int(idx), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||
"two distinct versions, order depends on map iteration")
|
||||
|
||||
idxByLabel := map[string]uint32{}
|
||||
for _, p := range full.Peers {
|
||||
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||
}
|
||||
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Policies[0].Enabled = false
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Groups["group-src"].AccountSeqID = 0
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||
// canonicalize identically.
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
|
||||
var byLabel = map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
v6Only := &nbpeer.Peer{
|
||||
ID: "peer-v6",
|
||||
Key: testWgKeyA,
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||
DNSLabel: "peerv6",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.Peers["peer-v6"] = v6Only
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerv6" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||
assert.Len(t, found.Ipv6, 16)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||
ID: "peer-noip",
|
||||
Key: testWgKeyA,
|
||||
DNSLabel: "peernoip",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peernoip" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found)
|
||||
assert.Empty(t, found.Ip)
|
||||
assert.Empty(t, found.Ipv6)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
}
|
||||
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Groups)
|
||||
assert.Empty(t, full.Policies)
|
||||
assert.Empty(t, full.RouterPeerIndexes)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
c.Peers["peer-a"].UserID = "user-1"
|
||||
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||
c.Peers["peer-a"].LastLogin = &now
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var pa *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peera" {
|
||||
pa = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pa)
|
||||
assert.True(t, pa.AddedWithSsoLogin)
|
||||
assert.True(t, pa.LoginExpirationEnabled)
|
||||
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||
|
||||
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||
var pb *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerb" {
|
||||
pb = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pb)
|
||||
assert.False(t, pb.AddedWithSsoLogin)
|
||||
assert.False(t, pb.LoginExpirationEnabled)
|
||||
assert.Zero(t, pb.LastLoginUnixNano)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{
|
||||
{
|
||||
ID: "route-peer",
|
||||
AccountSeqID: 100,
|
||||
NetID: "net-A",
|
||||
Description: "via peer-c",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Peer: "peer-c", // peer ID, not WG key
|
||||
Groups: []string{"group-src"},
|
||||
AccessControlGroups: []string{"group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-peergroup",
|
||||
AccountSeqID: 101,
|
||||
NetID: "net-B",
|
||||
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||
PeerGroups: []string{"group-src", "group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-no-seq",
|
||||
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 3)
|
||||
byNetID := map[string]*proto.RouteRaw{}
|
||||
for _, r := range full.Routes {
|
||||
byNetID[r.NetId] = r
|
||||
}
|
||||
|
||||
r1 := byNetID["net-A"]
|
||||
require.NotNil(t, r1)
|
||||
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||
assert.Empty(t, r1.PeerGroupIds)
|
||||
|
||||
r2 := byNetID["net-B"]
|
||||
require.NotNil(t, r2)
|
||||
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{{
|
||||
ID: "route-x",
|
||||
AccountSeqID: 100,
|
||||
Peer: "peer-not-in-components",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Enabled: true,
|
||||
}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 1)
|
||||
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||
"missing peer reference must not pretend to point at peer index 0")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||
// is the I1 case — without unionPolicies the encoder would silently
|
||||
// drop it from the wire.
|
||||
resourceOnlyPolicy := &types.Policy{
|
||||
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
}
|
||||
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||
}
|
||||
// Resource must appear in components.NetworkResources with a seq id —
|
||||
// encoder uses that to translate the xid map key to uint32.
|
||||
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||
|
||||
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||
policyIdxByID := map[uint32]uint32{}
|
||||
for i, p := range full.Policies {
|
||||
policyByID[p.Id] = p
|
||||
policyIdxByID[p.Id] = uint32(i)
|
||||
}
|
||||
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||
|
||||
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||
require.Len(t, idxs, 2)
|
||||
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||
"resource policies map must reference both wire policy indexes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||
}},
|
||||
Groups: []string{"group-src", "group-not-persisted"},
|
||||
Primary: true, Enabled: true,
|
||||
Domains: []string{"corp.example"},
|
||||
}}
|
||||
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.NameserverGroups, 1)
|
||||
nsg := full.NameserverGroups[0]
|
||||
assert.EqualValues(t, 50, nsg.Id)
|
||||
assert.Equal(t, "Main", nsg.Name)
|
||||
assert.True(t, nsg.Primary)
|
||||
require.Len(t, nsg.Nameservers, 1)
|
||||
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||
"check-1": {
|
||||
"peer-a": {},
|
||||
"peer-b": {},
|
||||
"peer-not-in-account": {},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {
|
||||
"peer-c": {
|
||||
ID: "router-1", AccountSeqID: 200,
|
||||
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
entries := full.RoutersMap[5].Entries
|
||||
require.Len(t, entries, 1)
|
||||
e := entries[0]
|
||||
assert.EqualValues(t, 200, e.Id)
|
||||
assert.True(t, e.PeerIndexSet)
|
||||
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||
assert.True(t, e.Masquerade)
|
||||
assert.EqualValues(t, 10, e.Metric)
|
||||
assert.True(t, e.Enabled)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||
// peer_index reference must still resolve.
|
||||
c := newTestComponents()
|
||||
delete(c.Peers, "peer-c")
|
||||
routerPeer := &nbpeer.Peer{
|
||||
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||
e := full.RoutersMap[5].Entries[0]
|
||||
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.DnsSettings)
|
||||
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.GroupIDToUserIDs = map[string][]string{
|
||||
"group-src": {"user-1", "user-2"},
|
||||
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||
"group-missing": {"user-4"}, // group not in components → drop
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||
}
|
||||
|
||||
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||
}
|
||||
|
||||
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||
nm := &types.NetworkMap{
|
||||
Peers: []*nbpeer.Peer{{
|
||||
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}},
|
||||
FirewallRules: []*types.FirewallRule{{
|
||||
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||
}},
|
||||
}
|
||||
|
||||
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||
|
||||
require.NotNil(t, patch)
|
||||
assert.Len(t, patch.Peers, 1)
|
||||
assert.Len(t, patch.FirewallRules, 1)
|
||||
}
|
||||
|
||||
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||
// pass-through in both encoder branches (normal path + nil-Components
|
||||
// graceful-degrade). Without this test a regression that drops `ProxyPatch:`
|
||||
// from one of the struct literals in components_encoder.go would slip past CI.
|
||||
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||
patch := &proto.ProxyPatch{
|
||||
ForwardingRules: []*proto.ForwardingRule{{
|
||||
Protocol: proto.RuleProtocol_TCP,
|
||||
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run("normal_path", func(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
|
||||
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||
// account_components.go:43 behaviour for missing/unvalidated peers.
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
// AccountSettings deliberately nil
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||
}
|
||||
193
management/internals/shared/grpc/components_envelope_response.go
Normal file
193
management/internals/shared/grpc/components_envelope_response.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||
// field is intentionally left empty — capable peers ignore it and the
|
||||
// envelope alone is the authoritative wire shape.
|
||||
//
|
||||
// PeerConfig is computed once server-side using the receiving peer's own
|
||||
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||
// client even though the server-side PeerConfig reports false.
|
||||
func ToComponentSyncResponse(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
httpConfig *nbconfig.HttpServerConfig,
|
||||
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||
peer *nbpeer.Peer,
|
||||
turnCredentials *Token,
|
||||
relayCredentials *Token,
|
||||
components *types.NetworkMapComponents,
|
||||
proxyPatch *types.NetworkMap,
|
||||
dnsName string,
|
||||
checks []*posture.Checks,
|
||||
settings *types.Settings,
|
||||
extraSettings *types.ExtraSettings,
|
||||
peerGroups []string,
|
||||
dnsFwdPort int64,
|
||||
) *proto.SyncResponse {
|
||||
network := networkOrZero(components)
|
||||
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: peerConfig,
|
||||
DNSDomain: dnsName,
|
||||
DNSForwarderPort: dnsFwdPort,
|
||||
UserIDClaim: userIDClaim,
|
||||
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||
})
|
||||
|
||||
resp := &proto.SyncResponse{
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMapEnvelope: envelope,
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||
// dereferences network.Net which would panic on nil.
|
||||
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||
if c == nil || c.Network == nil {
|
||||
return &types.Network{}
|
||||
}
|
||||
return c.Network
|
||||
}
|
||||
|
||||
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||
// patch the components envelope ships alongside. Returns nil when there are
|
||||
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||
// sees no patch and skips the merge step entirely.
|
||||
//
|
||||
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||
// delivers fragments pre-expanded — there's no raw component shape to
|
||||
// derive them from. Components purity isn't violated: proxy data isn't
|
||||
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||
// client merges it on top of its locally-computed NetworkMap.
|
||||
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
patch := &proto.ProxyPatch{
|
||||
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||
}
|
||||
if len(nm.ForwardingRules) > 0 {
|
||||
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||
for _, r := range nm.ForwardingRules {
|
||||
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||
}
|
||||
}
|
||||
return patch
|
||||
}
|
||||
|
||||
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||
// without this helper the field would be incorrectly false for any peer
|
||||
// that's the destination of an SSH-enabling policy without having
|
||||
// peer.SSHEnabled set locally.
|
||||
//
|
||||
// Mirrors the two activation paths in Calculate() (`networkmap_components.go`
|
||||
// `getPeerConnectionResources`):
|
||||
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||
// destinations.
|
||||
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||
//
|
||||
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||
// runs Calculate() over the envelope.
|
||||
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||
if c == nil || peer == nil {
|
||||
return false
|
||||
}
|
||||
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||
// exist in c.Peers, otherwise no rule applies to it.
|
||||
if _, ok := c.Peers[peer.ID]; !ok {
|
||||
return false
|
||||
}
|
||||
for _, policy := range c.Policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||
// peer itself has SSH enabled locally.
|
||||
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||
if rule == nil || !rule.Enabled {
|
||||
return false
|
||||
}
|
||||
if !peerInDestinations(c, rule, peer.ID) {
|
||||
return false
|
||||
}
|
||||
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
return true
|
||||
}
|
||||
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||
// that exactly to avoid silent divergence).
|
||||
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
return rule.DestinationResource.ID == peerID
|
||||
}
|
||||
for _, groupID := range rule.Destinations {
|
||||
if c.IsPeerInGroup(peerID, groupID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||
// the destination peer has SSHEnabled=true locally. Belt-and-suspenders for
|
||||
// the B1 fix that the prod-DB equivalence test alone wouldn't have caught
|
||||
// if no account had this combination.
|
||||
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||
const targetPeerID = "target"
|
||||
const targetGroupID = "g_dst"
|
||||
|
||||
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||
return &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||
Groups: map[string]*types.Group{targetGroupID: group},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
}},
|
||||
}, peer
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
peerSSH bool
|
||||
rule types.PolicyRule
|
||||
wantEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22022-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-all-protocol-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-port-range-covers-22",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "tcp-80-no-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "disabled-rule-skipped",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-not-in-destinations",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g_other"}, // target not in this group
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-typed-destination-resource-matches",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||
got := computeSSHEnabledForPeer(c, peer)
|
||||
assert.Equal(t, tc.wantEnabled, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||
// belt-and-suspenders presence guard mirroring Calculate's
|
||||
// getAllPeersFromGroups invariant.
|
||||
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||
c := &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||
Groups: map[string]*types.Group{
|
||||
"g": {ID: "g", Peers: []string{"missing"}},
|
||||
},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g"},
|
||||
}},
|
||||
}},
|
||||
}
|
||||
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||
"missing target peer must short-circuit to false, not consult policies")
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||
// components on graceful-degrade paths.
|
||||
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||
}
|
||||
@@ -7,23 +7,18 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
@@ -138,8 +133,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
@@ -152,19 +147,19 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
@@ -177,7 +172,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
@@ -188,78 +183,6 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
return response
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
@@ -277,204 +200,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
||||
// alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if shouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||
if config == nil || config.AuthAudience == "" {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -61,13 +62,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -99,7 +100,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -107,7 +108,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -932,7 +932,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
var plainResp *proto.SyncResponse
|
||||
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||
// computed and recompute the raw components instead. This wastes one
|
||||
// Calculate() call per initial-sync — the component-based wire
|
||||
// format is what the peer actually consumes. The streaming path
|
||||
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||
// because it dispatches by capability before computing.
|
||||
//
|
||||
// TODO(step-4-sync): refactor SyncPeer / SyncAndMarkPeer / their
|
||||
// mocks + manager interfaces to return PeerNetworkMapResult so the
|
||||
// initial-sync path stops doing duplicate work. ~13 files of churn,
|
||||
// deferred until the client-side decoder lands and there's a real
|
||||
// deployment of capability=3 peers worth optimizing for.
|
||||
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||
}
|
||||
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
} else {
|
||||
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
}
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -1621,6 +1621,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range newGroupsToCreate {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
|
||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
@@ -1868,35 +1876,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
|
||||
// network map and then marks the peer connected with a session token
|
||||
// derived from syncTime (the moment the gRPC stream opened). Any
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
|
||||
if err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, postureChecks, dnsfwdPort, nil
|
||||
}
|
||||
|
||||
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
|
||||
// peer disconnected only when the stored SessionStartedAt matches the
|
||||
// nanosecond token derived from streamStartTime — i.e. only when this
|
||||
// is the stream that currently owns the peer's session. A mismatch
|
||||
// means a newer stream has already replaced us, so the disconnect is
|
||||
// dropped.
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if peer.Status.LastSeen.After(streamStartTime) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
|
||||
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
|
||||
if err != nil {
|
||||
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -61,7 +61,8 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
|
||||
@@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// OnPeerDisconnected mocks base method.
|
||||
|
||||
@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err, "unable to get peer")
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal the token we passed in")
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
require.NoError(t, err)
|
||||
@@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.False(t, peer.Status.Connected, "peer should be disconnected")
|
||||
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
|
||||
})
|
||||
|
||||
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
|
||||
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
|
||||
// Older stream tries to mark disconnect with its own (older) session token —
|
||||
// fencing kicks in and the write is dropped.
|
||||
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
|
||||
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
|
||||
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected,
|
||||
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
|
||||
"peer should remain connected because the stored session is newer than the disconnect token")
|
||||
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should still hold the winning stream's token")
|
||||
})
|
||||
|
||||
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should still be connected")
|
||||
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
|
||||
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
|
||||
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
|
||||
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
|
||||
// fencing protocol under contention: many goroutines race to mark the
|
||||
// same peer connected with distinct session tokens at the same time.
|
||||
// The contract is that the highest token always wins and is what remains
|
||||
// in the store, regardless of execution order.
|
||||
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
const workers = 16
|
||||
base := time.Now().UTC().UnixNano()
|
||||
tokens := make([]int64, workers)
|
||||
for i := range tokens {
|
||||
// Spread tokens by 1ms so the comparison is unambiguous; the
|
||||
// largest is index workers-1.
|
||||
tokens[i] = base + int64(i)*int64(time.Millisecond)
|
||||
}
|
||||
expected := tokens[workers-1]
|
||||
|
||||
var ready sync.WaitGroup
|
||||
ready.Add(workers)
|
||||
var start sync.WaitGroup
|
||||
start.Add(1)
|
||||
var done sync.WaitGroup
|
||||
done.Add(workers)
|
||||
|
||||
// require.* calls t.FailNow which is documented as unsafe from
|
||||
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
|
||||
// races with the WaitGroup). Collect errors here and assert from the
|
||||
// main goroutine after done.Wait().
|
||||
errs := make(chan error, workers)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
token := tokens[i]
|
||||
go func() {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
}()
|
||||
}
|
||||
|
||||
ready.Wait()
|
||||
start.Done()
|
||||
done.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err, "MarkPeerConnected must not error under contention")
|
||||
}
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, peer.Status.Connected, "peer should be connected after the race")
|
||||
require.Equal(t, expected, peer.Status.SessionStartedAt,
|
||||
"the largest token must win regardless of execution order")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -2957,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
|
||||
var newJWTGroup *types.Group
|
||||
for _, g := range groups {
|
||||
if g.Name == "group3" {
|
||||
newJWTGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
|
||||
@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
@@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
|
||||
// path (for example, /oauth2/device/callback) when completing the device flow, so
|
||||
// include it explicitly in addition to the legacy bare path and absolute URL.
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
|
||||
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
|
||||
|
||||
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
|
||||
dashboardRedirectURIs := c.DashboardRedirectURIs
|
||||
@@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
|
||||
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
|
||||
|
||||
redirectURIs := make([]string, 0)
|
||||
redirectURIs = append(redirectURIs, cliRedirectURIs...)
|
||||
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
@@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: dashboardRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
RedirectURIs: redirectURIs,
|
||||
},
|
||||
},
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
@@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func issuerRelativeDeviceCallback(issuer string) string {
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil || u.Path == "" {
|
||||
return "/device/callback"
|
||||
}
|
||||
return path.Join(u.Path, "/device/callback")
|
||||
}
|
||||
|
||||
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
|
||||
// and because Dex only allows exact matches, we need to make sure we always have both
|
||||
// versions of each provided uri
|
||||
@@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
|
||||
}
|
||||
|
||||
func validSessionCookieEncryptionKeyLength(length int) bool {
|
||||
|
||||
@@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "https://example.com/oauth2",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(t.TempDir(), "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
var cliRedirectURIs []string
|
||||
for _, client := range yamlConfig.StaticClients {
|
||||
if client.ID == staticClientCLI {
|
||||
cliRedirectURIs = client.RedirectURIs
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, cliRedirectURIs)
|
||||
assert.Contains(t, cliRedirectURIs, "/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
|
||||
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
|
||||
t.Setenv(sessionCookieEncryptionKeyEnv, "")
|
||||
|
||||
|
||||
156
management/server/migration/account_seq.go
Normal file
156
management/server/migration/account_seq.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||
// no-op once everything is consistent.
|
||||
//
|
||||
// Implemented as two table-wide SQL statements with window functions, one
|
||||
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||
// well under a second instead of the per-account-loop ~2 minutes.
|
||||
//
|
||||
// orderColumn is the column to use when assigning the deterministic ordering
|
||||
// (typically the primary-key string id).
|
||||
func BackfillAccountSeqIDs[T any](
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
entity types.AccountSeqEntity,
|
||||
orderColumn string,
|
||||
) error {
|
||||
var model T
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
table := quoteIdent(db, stmt.Schema.Table)
|
||||
orderCol := quoteIdent(db, orderColumn)
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var pending int64
|
||||
if err := tx.Raw(
|
||||
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||
).Scan(&pending).Error; err != nil {
|
||||
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||
}
|
||||
|
||||
if pending > 0 {
|
||||
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||
return fmt.Errorf("rank %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func quoteIdent(db *gorm.DB, name string) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql":
|
||||
return "`" + name + "`"
|
||||
case "postgres":
|
||||
return `"` + name + `"`
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres", "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
WITH max_seq AS (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
),
|
||||
ranked AS (
|
||||
SELECT p.id,
|
||||
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||
FROM %s p
|
||||
JOIN max_seq m ON p.account_id = m.account_id
|
||||
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||
)
|
||||
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||
FROM ranked
|
||||
WHERE %s.id = ranked.id
|
||||
`, table, orderCol, table, table, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
UPDATE %s p
|
||||
JOIN (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
) m ON p.account_id = m.account_id
|
||||
JOIN (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||
) r ON p.id = r.id
|
||||
SET p.account_seq_id = m.max_seq + r.rn
|
||||
`, table, table, orderCol, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql).Error
|
||||
}
|
||||
|
||||
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`, table)
|
||||
case "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql, string(entity)).Error
|
||||
}
|
||||
@@ -38,7 +38,8 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
@@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
|
||||
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
|
||||
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
|
||||
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
|
||||
// disconnect events. Falls through to nil when no hook is set, which
|
||||
// is the original behaviour.
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerDisconnectedFunc != nil {
|
||||
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
|
||||
@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNSGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -71,9 +71,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
network.ID = xid.New().String()
|
||||
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network seq id: %w", err)
|
||||
}
|
||||
network.AccountSeqID = seq
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
||||
@@ -102,14 +113,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
network.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
|
||||
@@ -252,3 +252,73 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedNetwork)
|
||||
}
|
||||
|
||||
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
|
||||
// non-zero AccountSeqID on the persisted network (allocated through the
|
||||
// account_seq_counters table).
|
||||
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-allocation-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
|
||||
// AccountSeqID even when the caller passes a zero value (the shape REST
|
||||
// handlers produce because the field is `json:"-"`).
|
||||
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-preserve-original",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &types.Network{
|
||||
AccountID: accountID,
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = manager.UpdateNetwork(ctx, userID, update)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -125,6 +125,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network resource seq id: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -231,6 +237,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = oldResource.AccountSeqID
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,9 @@ type NetworkResource struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
|
||||
Name string
|
||||
Description string
|
||||
Type NetworkResourceType
|
||||
@@ -93,17 +96,18 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
|
||||
|
||||
func (n *NetworkResource) Copy() *NetworkResource {
|
||||
return &NetworkResource{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,6 +102,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", err)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
@@ -166,6 +172,22 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
|
||||
if err == nil {
|
||||
router.AccountSeqID = oldRouter.AccountSeqID
|
||||
} else if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
|
||||
// PUT-as-upsert: caller may target a brand-new router id (used by
|
||||
// the dashboard's "save" flow). Allocate a fresh account_seq_id so
|
||||
// the upsert behaves the same as Create().
|
||||
seq, allocErr := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if allocErr != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", allocErr)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
} else {
|
||||
return fmt.Errorf("failed to get existing network router: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
|
||||
@@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,9 @@ type NetworkRouter struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
Masquerade bool
|
||||
@@ -78,14 +81,15 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
|
||||
|
||||
func (n *NetworkRouter) Copy() *NetworkRouter {
|
||||
return &NetworkRouter{
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,12 +7,24 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
|
||||
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the network has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip networks that return false here.
|
||||
func (n *Network) HasSeqID() bool {
|
||||
return n != nil && n.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func NewNetwork(accountId, name, description string) *Network {
|
||||
return &Network{
|
||||
ID: xid.New().String(),
|
||||
@@ -41,13 +53,14 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a copy of a posture checks.
|
||||
// Copy returns a copy of a network.
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -63,56 +63,64 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
// syncTime is used as the LastSeen timestamp and for stale request detection
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
var err error
|
||||
var skipped bool
|
||||
// MarkPeerConnected marks a peer as connected with optimistic-locked
|
||||
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
|
||||
// is the start time of the gRPC sync stream that owns this update,
|
||||
// expressed as Unix nanoseconds — only the call whose token is greater
|
||||
// than what's stored wins. LastSeen is written by the database itself;
|
||||
// we never pass it down.
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
}()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
|
||||
if connected && !syncTime.After(peer.Status.LastSeen) {
|
||||
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
|
||||
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
|
||||
skipped = true
|
||||
return nil
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
|
||||
return err
|
||||
})
|
||||
if skipped {
|
||||
}
|
||||
|
||||
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
expired := peer.Status != nil && peer.Status.LoginExpired
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
|
||||
if err != nil {
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -120,41 +128,60 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = syncTime
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
newStatus.LoginExpired = false
|
||||
}
|
||||
peer.Status = newStatus
|
||||
// MarkPeerDisconnected marks a peer as disconnected, but only when the
|
||||
// stored session token matches the one passed in. A mismatch means a
|
||||
// newer stream has already taken ownership of the peer — disconnects from
|
||||
// the older stream are ignored. LastSeen is written by the database.
|
||||
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
|
||||
}()
|
||||
|
||||
if geo != nil && realIP != nil {
|
||||
location, err := geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = transaction.SavePeerLocation(ctx, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
outcome := telemetry.PeerStatusError
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
outcome = telemetry.PeerStatusPeerNotFound
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
|
||||
return err
|
||||
}
|
||||
|
||||
return oldStatus.LoginExpired, nil
|
||||
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
|
||||
if err != nil {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
|
||||
return err
|
||||
}
|
||||
if !updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
|
||||
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
|
||||
peer.ID, sessionStartedAt)
|
||||
return nil
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return
|
||||
}
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
|
||||
@@ -13,8 +13,9 @@ import (
|
||||
|
||||
// Peer capability constants mirror the proto enum values.
|
||||
const (
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilityComponentNetworkMap int32 = 3
|
||||
)
|
||||
|
||||
// Peer represents a machine connected to the network.
|
||||
@@ -74,8 +75,19 @@ type ProxyMeta struct {
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
// LastSeen is the last time peer was connected to the management service
|
||||
// LastSeen is the last time the peer status was updated (i.e. the last
|
||||
// time we observed the peer being alive on a sync stream). Written by
|
||||
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
|
||||
LastSeen time.Time
|
||||
// SessionStartedAt records when the currently-active sync stream began,
|
||||
// stored as Unix nanoseconds. It acts as the optimistic-locking token
|
||||
// for status updates: a stream is only allowed to mutate the peer's
|
||||
// status when its own token strictly exceeds the stored token (when connecting)
|
||||
// or matches it exactly (for disconnects). Zero means "no
|
||||
// active session". Integer nanoseconds are used so equality is
|
||||
// precision-safe across drivers, and so the predicates compose to a
|
||||
// single bigint comparison.
|
||||
SessionStartedAt int64
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
// LoginExpired
|
||||
@@ -236,6 +248,14 @@ func (p *Peer) SupportsSourcePrefixes() bool {
|
||||
return p.HasCapability(PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
|
||||
// SupportsComponentNetworkMap reports whether the peer assembles its
|
||||
// NetworkMap from server-shipped components instead of consuming a fully
|
||||
// expanded NetworkMap. Determines whether the network_map controller skips
|
||||
// Calculate() server-side and emits the components envelope.
|
||||
func (p *Peer) SupportsComponentNetworkMap() bool {
|
||||
return p.HasCapability(PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func capabilitiesEqual(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
@@ -375,10 +395,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
// Copy PeerStatus
|
||||
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
|
||||
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
|
||||
// reset the fencing token to zero — that would let any subsequent
|
||||
// SavePeerStatus write reopen the optimistic-lock window.
|
||||
func (p *PeerStatus) Copy() *PeerStatus {
|
||||
return &PeerStatus{
|
||||
LastSeen: p.LastSeen,
|
||||
SessionStartedAt: p.SessionStartedAt,
|
||||
Connected: p.Connected,
|
||||
LoginExpired: p.LoginExpired,
|
||||
RequiresApproval: p.RequiresApproval,
|
||||
|
||||
@@ -69,6 +69,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
policy.AccountSeqID = existingPolicy.AccountSeqID
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,6 +80,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,10 +47,21 @@ type Checks struct {
|
||||
// AccountID is a reference to the Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
|
||||
|
||||
// Checks is a set of objects that perform the actual checks
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the posture check has been persisted long enough
|
||||
// to have a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip checks that return false here.
|
||||
func (pc *Checks) HasSeqID() bool {
|
||||
return pc != nil && pc.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
@@ -121,11 +132,12 @@ func (*Checks) TableName() string {
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (pc *Checks) Copy() *Checks {
|
||||
checks := &Checks{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
AccountSeqID: pc.AccountSeqID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -51,12 +51,24 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = seq
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
|
||||
@@ -563,3 +563,61 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_AllocatesSeqIDOnCreate verifies that the create path
|
||||
// (no incoming ID) allocates a non-zero AccountSeqID via the
|
||||
// account_seq_counters table.
|
||||
func TestSavePostureChecks_AllocatesSeqIDOnCreate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-allocation-test",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "SavePostureChecks on create must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_PreservesSeqIDOnUpdate verifies the update path does
|
||||
// not reset AccountSeqID even when the caller passes a zero value (REST
|
||||
// handler shape, because the field is `json:"-"`).
|
||||
func TestSavePostureChecks_PreservesSeqIDOnUpdate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-preserve-original",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &posture.Checks{
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.27.0"},
|
||||
},
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, update, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := am.GetPostureChecks(context.Background(), account.Id, created.ID, adminUserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive SavePostureChecks update")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -178,6 +178,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRoute.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -231,6 +237,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
routeToSave.AccountSeqID = oldRoute.AccountSeqID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
|
||||
506
management/server/store/account_seq_test.go
Normal file
506
management/server/store/account_seq_test.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
var errRollback = errors.New("intentional rollback")
|
||||
|
||||
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accA = "acc-a"
|
||||
const accB = "acc-b"
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different account starts from 1")
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different entity starts from 1")
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), got, "counter persists across transactions")
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, policies, "test fixture must have policies")
|
||||
|
||||
seen := make(map[uint32]bool)
|
||||
for _, p := range policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
|
||||
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
|
||||
seen[p.AccountSeqID] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
const policyID = "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
|
||||
|
||||
updated := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Enabled: false,
|
||||
Rules: original.Rules,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
|
||||
|
||||
require.NoError(t, store.SavePolicy(ctx, updated))
|
||||
|
||||
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups)
|
||||
|
||||
original := groups[0]
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
updated := &types.Group{
|
||||
ID: original.ID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Issued: original.Issued,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID)
|
||||
|
||||
require.NoError(t, store.UpdateGroup(ctx, updated))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-seqid-test"
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
DNSSettings: types.DNSSettings{},
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
}
|
||||
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
|
||||
require.Len(t, account.Groups, 1, "default 'All' group must be present")
|
||||
require.Len(t, account.Policies, 1, "default policy must be present")
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
|
||||
}
|
||||
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
|
||||
|
||||
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
|
||||
|
||||
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
|
||||
return errRollback
|
||||
}), errRollback)
|
||||
}
|
||||
|
||||
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupSeqs := make(map[string]uint32)
|
||||
policySeqs := make(map[string]uint32)
|
||||
routeSeqs := make(map[route.ID]uint32)
|
||||
nsgSeqs := make(map[string]uint32)
|
||||
resourceSeqs := make(map[string]uint32)
|
||||
routerSeqs := make(map[string]uint32)
|
||||
networkSeqs := make(map[string]uint32)
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
|
||||
groupSeqs[g.ID] = g.AccountSeqID
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
|
||||
policySeqs[p.ID] = p.AccountSeqID
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
|
||||
routeSeqs[r.ID] = r.AccountSeqID
|
||||
}
|
||||
for _, n := range account.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
|
||||
nsgSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
|
||||
resourceSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
|
||||
routerSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture network must have seq>0 after backfill")
|
||||
networkSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, g := range after.Groups {
|
||||
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.Equal(t, networkSeqs[n.ID], n.AccountSeqID, "network %s seq must be preserved", n.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-all-entities"
|
||||
|
||||
addr, err := netip.ParseAddr("8.8.8.8")
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{Identifier: "net-test"},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
|
||||
},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
|
||||
},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
|
||||
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{ID: "n1", AccountID: accountID, Name: "n1"},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: "pc1", AccountID: accountID, Name: "pc1",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, after.Groups, 1)
|
||||
require.Len(t, after.Policies, 1)
|
||||
require.Len(t, after.Routes, 1)
|
||||
require.Len(t, after.NameServerGroups, 1)
|
||||
require.Len(t, after.NetworkResources, 1)
|
||||
require.Len(t, after.NetworkRouters, 1)
|
||||
require.Len(t, after.Networks, 1)
|
||||
require.Len(t, after.PostureChecks, 1)
|
||||
|
||||
for _, g := range after.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "network seq must be allocated")
|
||||
}
|
||||
for _, pc := range after.PostureChecks {
|
||||
require.NotZero(t, pc.AccountSeqID, "posture_check seq must be allocated")
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, after))
|
||||
final, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, r := range final.Routes {
|
||||
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
|
||||
}
|
||||
for _, n := range final.NameServerGroups {
|
||||
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
|
||||
}
|
||||
afterByID := map[string]uint32{}
|
||||
for _, n := range after.Networks {
|
||||
afterByID[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, n := range final.Networks {
|
||||
require.Equal(t, afterByID[n.ID], n.AccountSeqID, "network seq preserved on re-save")
|
||||
}
|
||||
afterPCByID := map[string]uint32{}
|
||||
for _, pc := range after.PostureChecks {
|
||||
afterPCByID[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
for _, pc := range final.PostureChecks {
|
||||
require.Equal(t, afterPCByID[pc.ID], pc.AccountSeqID, "posture_check seq preserved on re-save")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "concurrent-test"
|
||||
const entity = types.AccountSeqEntityPolicy
|
||||
const goroutines = 32
|
||||
|
||||
type result struct {
|
||||
seq uint32
|
||||
err error
|
||||
}
|
||||
results := make(chan result, goroutines)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
<-start
|
||||
var allocated uint32
|
||||
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
|
||||
allocated = seq
|
||||
return err
|
||||
})
|
||||
results <- result{seq: allocated, err: err}
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
|
||||
seen := make(map[uint32]int, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
r := <-results
|
||||
require.NoError(t, r.err, "concurrent allocate must not fail")
|
||||
require.NotZero(t, r.seq, "allocated seq must be non-zero")
|
||||
seen[r.seq]++
|
||||
}
|
||||
|
||||
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
|
||||
for i := uint32(1); i <= goroutines; i++ {
|
||||
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*types.Group{
|
||||
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
|
||||
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
for _, want := range groups {
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
|
||||
}
|
||||
|
||||
groups[0].Name = "g1-renamed"
|
||||
groups[0].AccountSeqID = 0
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
|
||||
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
|
||||
}
|
||||
|
||||
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
maxSeq := uint32(0)
|
||||
for _, p := range existing {
|
||||
if p.AccountSeqID > maxSeq {
|
||||
maxSeq = p.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
|
||||
|
||||
newPolicy := &types.Policy{
|
||||
ID: "bench-new-policy",
|
||||
AccountID: accountID,
|
||||
AccountSeqID: seq,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "bench-new-policy-rule",
|
||||
PolicyID: "bench-new-policy",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
}},
|
||||
}
|
||||
return tx.CreatePolicy(ctx, newPolicy)
|
||||
}))
|
||||
|
||||
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxSeq+1, created.AccountSeqID)
|
||||
}
|
||||
@@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
&types.AccountSeqCounter{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -307,6 +308,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
|
||||
return fmt.Errorf("assign seq ids: %w", err)
|
||||
}
|
||||
|
||||
result = tx.
|
||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
@@ -498,8 +503,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
peerCopy.Status = &peerStatus
|
||||
|
||||
fieldsToUpdate := []string{
|
||||
"peer_status_last_seen", "peer_status_connected",
|
||||
"peer_status_login_expired", "peer_status_required_approval",
|
||||
"peer_status_last_seen", "peer_status_session_started_at",
|
||||
"peer_status_connected", "peer_status_login_expired",
|
||||
"peer_status_requires_approval",
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
@@ -516,6 +522,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update.
|
||||
// The peer is marked connected with the given session token only when
|
||||
// the stored SessionStartedAt is strictly smaller than the incoming
|
||||
// one — equivalently, when no newer stream has already taken ownership.
|
||||
// The sentinel zero (set on peer creation or after a disconnect) counts
|
||||
// as the smallest possible token. This is the write half of the
|
||||
// fencing protocol described on PeerStatus.SessionStartedAt.
|
||||
//
|
||||
// The post-write side effects in the caller — geo lookup,
|
||||
// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration,
|
||||
// OnPeersUpdated — all run AFTER this method returns and are deliberately
|
||||
// outside the database write so they cannot extend the row-lock window.
|
||||
//
|
||||
// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the
|
||||
// moment the row is written. The caller never supplies LastSeen because
|
||||
// the value would otherwise drift under lock contention — a Go-side
|
||||
// time.Now() taken before the write can land minutes later than the
|
||||
// actual UPDATE under load, which previously caused real ordering bugs.
|
||||
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at < ?", newSessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": true,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": newSessionStartedAt,
|
||||
"peer_status_login_expired": false,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update.
|
||||
// The peer is marked disconnected only when the stored SessionStartedAt
|
||||
// matches the incoming token — meaning the stream that owns the current
|
||||
// session is the one ending. If a newer stream has already replaced the
|
||||
// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at
|
||||
// write time; see MarkPeerConnectedIfNewerSession for the rationale.
|
||||
//
|
||||
// A zero sessionStartedAt is rejected at the call site; the underlying
|
||||
// WHERE on equality would otherwise match every never-connected peer.
|
||||
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
if sessionStartedAt == 0 {
|
||||
return false, nil
|
||||
}
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Where("peer_status_session_started_at = ?", sessionStartedAt).
|
||||
Updates(map[string]any{
|
||||
"peer_status_connected": false,
|
||||
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"peer_status_session_started_at": int64(0),
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error)
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
@@ -594,6 +663,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
}
|
||||
|
||||
// CreateGroups creates the given list of groups to the database.
|
||||
// groupUpsertColumns is the explicit allowlist of columns that get updated when
|
||||
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
|
||||
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
|
||||
// handler-built struct) cannot reset the persisted seq id during an upsert.
|
||||
// Keep this in sync with the Group schema in management/server/types/group.go.
|
||||
func groupUpsertColumns() clause.Set {
|
||||
return clause.AssignmentColumns([]string{
|
||||
"account_id",
|
||||
"name",
|
||||
"issued",
|
||||
"integration_ref_id",
|
||||
"integration_ref_integration_type",
|
||||
"resources",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
@@ -603,8 +688,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -628,8 +714,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -1723,9 +1810,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
|
||||
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
|
||||
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
|
||||
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
|
||||
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
|
||||
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
|
||||
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
|
||||
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
|
||||
FROM peers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1738,6 +1826,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
lastLogin, createdAt sql.NullTime
|
||||
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
|
||||
peerStatusLastSeen sql.NullTime
|
||||
peerStatusSessionStartedAt sql.NullInt64
|
||||
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
|
||||
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
|
||||
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
|
||||
@@ -1752,8 +1841,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
|
||||
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
|
||||
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
|
||||
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
|
||||
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
|
||||
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
|
||||
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
|
||||
&proxyEmbedded, &proxyCluster, &ipv6)
|
||||
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
@@ -1780,6 +1870,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
if peerStatusLastSeen.Valid {
|
||||
p.Status.LastSeen = peerStatusLastSeen.Time
|
||||
}
|
||||
if peerStatusSessionStartedAt.Valid {
|
||||
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
|
||||
}
|
||||
if peerStatusConnected.Valid {
|
||||
p.Status.Connected = peerStatusConnected.Bool
|
||||
}
|
||||
@@ -1925,7 +2018,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
|
||||
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1935,7 +2028,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
var resources []byte
|
||||
var refID sql.NullInt64
|
||||
var refType sql.NullString
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
if err == nil {
|
||||
if refID.Valid {
|
||||
g.IntegrationReference.ID = int(refID.Int64)
|
||||
@@ -1960,7 +2053,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
|
||||
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1969,7 +2062,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
var p types.Policy
|
||||
var checks []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
p.Enabled = enabled.Bool
|
||||
@@ -1987,7 +2080,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
}
|
||||
|
||||
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
|
||||
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1997,7 +2090,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
var network, domains, peerGroups, groups, accessGroups []byte
|
||||
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
if err == nil {
|
||||
if keepRoute.Valid {
|
||||
r.KeepRoute = keepRoute.Bool
|
||||
@@ -2039,7 +2132,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
|
||||
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2048,7 +2141,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
var n nbdns.NameServerGroup
|
||||
var ns, groups, domains []byte
|
||||
var primary, enabled, searchDomainsEnabled sql.NullBool
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
if err == nil {
|
||||
if primary.Valid {
|
||||
n.Primary = primary.Bool
|
||||
@@ -2084,7 +2177,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2092,7 +2185,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
||||
var c posture.Checks
|
||||
var checksDef []byte
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.AccountSeqID, &c.Name, &c.Description, &checksDef)
|
||||
if err == nil && checksDef != nil {
|
||||
_ = json.Unmarshal(checksDef, &c.Checks)
|
||||
}
|
||||
@@ -2258,7 +2351,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
||||
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description FROM networks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2275,7 +2368,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2285,7 +2378,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
var peerGroups []byte
|
||||
var masquerade, enabled sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
if err == nil {
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
@@ -2313,7 +2406,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2322,7 +2415,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
|
||||
var r resourceTypes.NetworkResource
|
||||
var prefix []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -3495,6 +3588,262 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must be called inside ExecuteInTransaction so the increment
|
||||
// is serialized with the component insert.
|
||||
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
|
||||
}
|
||||
|
||||
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
return allocateAccountSeqIDReturning(db, accountID, entity)
|
||||
case types.MysqlStoreEngine:
|
||||
return allocateAccountSeqIDMysql(db, accountID, entity)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
|
||||
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
|
||||
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
|
||||
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
|
||||
// branch and returns next_id+1.
|
||||
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, 2)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = account_seq_counters.next_id + 1
|
||||
RETURNING (next_id - 1)
|
||||
`
|
||||
var allocated uint32
|
||||
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
if allocated == 0 {
|
||||
return 0, fmt.Errorf("upsert account seq counter returned 0")
|
||||
}
|
||||
return allocated, nil
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
|
||||
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
|
||||
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
|
||||
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
|
||||
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
|
||||
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
|
||||
// so the follow-up SELECT sees the same value.
|
||||
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const upsertSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, LAST_INSERT_ID(2))
|
||||
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
|
||||
`
|
||||
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
var newNext uint64
|
||||
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
|
||||
return 0, fmt.Errorf("get last insert id: %w", err)
|
||||
}
|
||||
if newNext == 0 {
|
||||
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
|
||||
}
|
||||
return uint32(newNext - 1), nil
|
||||
}
|
||||
|
||||
// assignAccountSeqIDs allocates a per-account integer id for any component on
|
||||
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
|
||||
// the canonical "save the whole account" path produces the same persisted seq
|
||||
// ids that the manager-level Create paths produce. Update flows that go
|
||||
// through SaveAccount preserve existing non-zero values; for those, the
|
||||
// per-entity counter is bumped so subsequent AllocateAccountSeqID calls don't
|
||||
// hand out a colliding id.
|
||||
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
|
||||
maxByEntity := make(map[types.AccountSeqEntity]uint32, 8)
|
||||
bump := func(entity types.AccountSeqEntity, seq uint32) {
|
||||
if seq > maxByEntity[entity] {
|
||||
maxByEntity[entity] = seq
|
||||
}
|
||||
}
|
||||
|
||||
for i := range account.GroupsG {
|
||||
g := account.GroupsG[i]
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
if g.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityGroup, g.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
// Defensive: generateAccountSQLTypes currently aliases the same
|
||||
// *Group pointer into GroupsG and Groups[id] (so this is a no-op
|
||||
// today), but mirror the seq anyway so any future divergence in
|
||||
// how the two collections are populated doesn't silently leave
|
||||
// the canonical map view stale.
|
||||
if original, ok := account.Groups[g.ID]; ok && original != nil && original != g {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPolicy, p.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.AccountSeqID = seq
|
||||
}
|
||||
for i := range account.RoutesG {
|
||||
r := &account.RoutesG[i]
|
||||
if r.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityRoute, r.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.AccountSeqID = seq
|
||||
// Mirror the new seq onto the canonical map view so callers that
|
||||
// hold the same in-memory account post-Save read a consistent
|
||||
// AccountSeqID — without this, components/encoder code would see
|
||||
// 0 for routes saved this transaction until the account is reloaded.
|
||||
if original, ok := account.Routes[r.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for i := range account.NameServerGroupsG {
|
||||
ng := &account.NameServerGroupsG[i]
|
||||
if ng.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNameserverGroup, ng.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ng.AccountSeqID = seq
|
||||
if original, ok := account.NameServerGroups[ng.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkResource, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkRouter, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
if n == nil {
|
||||
continue
|
||||
}
|
||||
if n.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetwork, n.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.AccountSeqID = seq
|
||||
}
|
||||
for _, pc := range account.PostureChecks {
|
||||
if pc == nil {
|
||||
continue
|
||||
}
|
||||
if pc.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPostureCheck, pc.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pc.AccountSeqID = seq
|
||||
}
|
||||
for entity, maxSeq := range maxByEntity {
|
||||
if err := ensureAccountSeqCounter(tx, s.storeEngine, account.Id, entity, maxSeq+1); err != nil {
|
||||
return fmt.Errorf("seed counter for %s: %w", entity, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureAccountSeqCounter raises the per-account counter for entity to at
|
||||
// least target. Used when SaveAccount persists components that already carry
|
||||
// AccountSeqIDs (e.g. test bulk-load from sqlite to postgres, or migrations
|
||||
// running before component data lands) so that the next AllocateAccountSeqID
|
||||
// call returns a fresh id beyond what was just written.
|
||||
func ensureAccountSeqCounter(db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity, target uint32) error {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`
|
||||
// sqlite's UPSERT understands max() but the migration uses GREATEST
|
||||
// for postgres and max() for sqlite. We collapse to dialect-specific
|
||||
// statements only when needed.
|
||||
if engine == types.SqliteStoreEngine {
|
||||
const sqliteSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`
|
||||
return db.Exec(sqliteSQL, accountID, string(entity), target).Error
|
||||
}
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
case types.MysqlStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
default:
|
||||
return fmt.Errorf("unsupported store engine for account_seq counter: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
|
||||
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
|
||||
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
|
||||
@@ -3684,7 +4033,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
@@ -3772,7 +4121,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
|
||||
@@ -167,6 +167,21 @@ type Store interface {
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
|
||||
// given session token, but only when the stored SessionStartedAt is
|
||||
// strictly less than newSessionStartedAt (the sentinel zero counts as
|
||||
// "older"). LastSeen is recorded by the database at the moment the
|
||||
// row is updated — never by the caller — so it always reflects the
|
||||
// real write time even under lock contention.
|
||||
// Returns true when the update happened, false when this stream lost
|
||||
// the race against a newer session.
|
||||
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
|
||||
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
|
||||
// resets SessionStartedAt to zero, but only when the stored
|
||||
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
@@ -205,6 +220,11 @@ type Store interface {
|
||||
GetStoreEngine() types.Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must run inside a transaction so the increment is serialized
|
||||
// with the component insert.
|
||||
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
||||
@@ -507,6 +527,30 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[networkTypes.Network](ctx, db, types.AccountSeqEntityNetwork, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[posture.Checks](ctx, db, types.AccountSeqEntityPostureCheck, "id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -746,6 +746,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID mocks base method.
|
||||
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
|
||||
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2878,6 +2893,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession mocks base method.
|
||||
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession mocks base method.
|
||||
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
|
||||
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// SavePolicy mocks base method.
|
||||
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -16,6 +16,8 @@ type AccountManagerMetrics struct {
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
peerStatusUpdateCounter metric.Int64Counter
|
||||
peerStatusUpdateDurationMs metric.Float64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@@ -64,6 +66,24 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected
|
||||
peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000,
|
||||
),
|
||||
metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
@@ -71,10 +91,35 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
updateAccountPeersCounter: updateAccountPeersCounter,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
peerStatusUpdateCounter: peerStatusUpdateCounter,
|
||||
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// PeerStatusOperation labels the kind of fence-locked peer status write.
|
||||
type PeerStatusOperation string
|
||||
|
||||
// PeerStatusOutcome labels how a fence-locked peer status write resolved.
|
||||
type PeerStatusOutcome string
|
||||
|
||||
const (
|
||||
PeerStatusConnect PeerStatusOperation = "connect"
|
||||
PeerStatusDisconnect PeerStatusOperation = "disconnect"
|
||||
|
||||
// PeerStatusApplied — the fence WHERE matched and the UPDATE landed.
|
||||
PeerStatusApplied PeerStatusOutcome = "applied"
|
||||
// PeerStatusStale — the fence WHERE rejected the write because a
|
||||
// newer session has already taken ownership (connect: stored token
|
||||
// >= incoming; disconnect: stored token != incoming).
|
||||
PeerStatusStale PeerStatusOutcome = "stale"
|
||||
// PeerStatusError — the store returned a non-NotFound error.
|
||||
PeerStatusError PeerStatusOutcome = "error"
|
||||
// PeerStatusPeerNotFound — the peer lookup failed (the peer was
|
||||
// deleted between the gRPC sync handshake and the status write).
|
||||
PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found"
|
||||
)
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
@@ -104,3 +149,23 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeerStatusUpdate increments the connect/disconnect counter,
|
||||
// labeled by operation and outcome. Both labels are bounded enums.
|
||||
func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) {
|
||||
metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("operation", string(op)),
|
||||
attribute.String("outcome", string(outcome)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// RecordPeerStatusUpdateDuration records the wall-clock time spent
|
||||
// running a peer status update (including post-write side effects),
|
||||
// labeled by operation.
|
||||
func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) {
|
||||
metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6,
|
||||
metric.WithAttributes(attribute.String("operation", string(op))),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type MockAppMetrics struct {
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@@ -103,6 +104,14 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface
|
||||
func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
if mock.EphemeralPeersMetricsFunc != nil {
|
||||
return mock.EphemeralPeersMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@@ -114,6 +123,7 @@ type AppMetrics interface {
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
EphemeralPeersMetrics() *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
@@ -129,6 +139,7 @@ type defaultAppMetrics struct {
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
ephemeralMetrics *EphemeralPeersMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@@ -161,6 +172,11 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop
|
||||
func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
|
||||
return appMetrics.ephemeralMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@@ -245,6 +261,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -254,6 +275,7 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -290,6 +312,11 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
|
||||
}
|
||||
|
||||
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
@@ -300,5 +327,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
ephemeralMetrics: ephemeralMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
115
management/server/telemetry/ephemeral_metrics.go
Normal file
115
management/server/telemetry/ephemeral_metrics.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
|
||||
// many peers are currently scheduled for deletion, how many tick runs
|
||||
// the cleaner has performed, how many peers it has removed, and how
|
||||
// many delete batches failed.
|
||||
type EphemeralPeersMetrics struct {
|
||||
ctx context.Context
|
||||
|
||||
pending metric.Int64UpDownCounter
|
||||
cleanupRuns metric.Int64Counter
|
||||
peersCleaned metric.Int64Counter
|
||||
errors metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
|
||||
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
|
||||
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &EphemeralPeersMetrics{
|
||||
ctx: ctx,
|
||||
pending: pending,
|
||||
cleanupRuns: cleanupRuns,
|
||||
peersCleaned: peersCleaned,
|
||||
errors: errors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// All methods are nil-receiver safe so callers that haven't wired metrics
|
||||
// (tests, self-hosted with metrics off) can invoke them unconditionally.
|
||||
|
||||
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
|
||||
func (m *EphemeralPeersMetrics) IncPending() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// AddPending bumps the pending gauge by n — used at startup when the
|
||||
// initial set of ephemeral peers is loaded from the store.
|
||||
func (m *EphemeralPeersMetrics) AddPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// DecPending decreases the pending gauge — used both when a peer reconnects
|
||||
// before its deadline (removed from the list) and when a cleanup tick
|
||||
// actually deletes it.
|
||||
func (m *EphemeralPeersMetrics) DecPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.pending.Add(m.ctx, -n)
|
||||
}
|
||||
|
||||
// CountCleanupRun records one cleanup pass that processed >0 peers. Idle
|
||||
// ticks (nothing to do) deliberately don't increment so the rate
|
||||
// reflects useful work.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupRun() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.cleanupRuns.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// CountPeersCleaned records the number of peers a single tick deleted.
|
||||
func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
m.peersCleaned.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// CountCleanupError records a failed delete batch.
|
||||
func (m *EphemeralPeersMetrics) CountCleanupError() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.errors.Add(m.ctx, 1)
|
||||
}
|
||||
@@ -1006,6 +1006,15 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
// PolicyRuleImpliesLegacySSH reports whether the rule (without an explicit
|
||||
// NetbirdSSH protocol) implicitly authorises SSH because it permits TCP/22 or
|
||||
// TCP/22022 — either by ALL-protocol coverage or by an explicit port/port-range
|
||||
// containing one of those. Exposed for ToComponentSyncResponse so the
|
||||
// envelope-format response mirrors the legacy SshConfig.SshEnabled bit.
|
||||
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return policyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
@@ -16,6 +16,49 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// GetPeerNetworkMapResult dispatches to either the legacy-NetworkMap path or
|
||||
// the components path based on the peer's capability and the kill switch.
|
||||
// Capable peers (PeerCapabilityComponentNetworkMap) get the raw components
|
||||
// shape — the server skips Calculate() entirely for them, saving CPU
|
||||
// proportional to the number of capable peers in the account. Legacy peers
|
||||
// (or any peer when componentsDisabled is true) get the fully-expanded
|
||||
// NetworkMap as before.
|
||||
func (a *Account) GetPeerNetworkMapResult(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
componentsDisabled bool,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) PeerNetworkMapResult {
|
||||
peer := a.Peers[peerID]
|
||||
if !componentsDisabled && peer != nil && peer.SupportsComponentNetworkMap() {
|
||||
components := a.GetPeerNetworkMapComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs,
|
||||
)
|
||||
// Mirror legacy graceful-degrade: GetPeerNetworkMapFromComponents
|
||||
// returns &NetworkMap{Network: a.Network.Copy()} when components is
|
||||
// nil. Match that floor so the receiving client always sees the
|
||||
// account Network identifier, not a fully-empty envelope.
|
||||
if components == nil {
|
||||
components = &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
return PeerNetworkMapResult{Components: components}
|
||||
}
|
||||
return PeerNetworkMapResult{
|
||||
NetworkMap: a.GetPeerNetworkMapFromComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, metrics, groupIDToUserIDs,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapFromComponents(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
@@ -82,15 +125,27 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
}
|
||||
|
||||
components := &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
NetworkXIDToSeq: make(map[string]uint32, len(a.Networks)),
|
||||
PostureCheckXIDToSeq: make(map[string]uint32, len(a.PostureChecks)),
|
||||
}
|
||||
for _, n := range a.Networks {
|
||||
if n != nil && n.HasSeqID() {
|
||||
components.NetworkXIDToSeq[n.ID] = n.AccountSeqID
|
||||
}
|
||||
}
|
||||
for _, pc := range a.PostureChecks {
|
||||
if pc != nil && pc.HasSeqID() {
|
||||
components.PostureCheckXIDToSeq[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
components.AccountSettings = &AccountSettingsInfo{
|
||||
@@ -253,18 +308,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
|
||||
relevantPeerIDs[peerID] = a.GetPeer(peerID)
|
||||
|
||||
peerGroupSet := make(map[string]struct{}, 8)
|
||||
for groupID, group := range a.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
peerGroupSet[groupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
routeAccessControlGroups := make(map[string]struct{})
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.Groups {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
relevant := r.Peer == peerID
|
||||
if !relevant {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant && r.Enabled {
|
||||
for _, groupID := range r.Groups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, groupID := range r.PeerGroups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
for _, groupID := range r.PeerGroups {
|
||||
for _, groupID := range r.Groups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
if r.Enabled {
|
||||
@@ -485,6 +566,13 @@ func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChe
|
||||
return dest
|
||||
}
|
||||
|
||||
// filterGroupPeers trims each group's Peers slice to only those peers that
|
||||
// also appear in `peers`. Groups whose filtered list is empty are NOT
|
||||
// deleted from the map — they're kept so the components wire encoder can
|
||||
// still resolve seq references from routes/policies/access-control groups
|
||||
// that name them. Calculate() tolerates groups with empty Peers (the inner
|
||||
// loops simply iterate zero times), so retaining them is behaviourally a
|
||||
// no-op for the legacy path that consumes the same NetworkMapComponents.
|
||||
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
|
||||
for groupID, groupInfo := range *groups {
|
||||
filteredPeers := make([]string, 0, len(groupInfo.Peers))
|
||||
@@ -494,9 +582,7 @@ func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredPeers) == 0 {
|
||||
delete(*groups, groupID)
|
||||
} else if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
ng := groupInfo.Copy()
|
||||
ng.Peers = filteredPeers
|
||||
(*groups)[groupID] = ng
|
||||
|
||||
29
management/server/types/account_seq_counter.go
Normal file
29
management/server/types/account_seq_counter.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package types
|
||||
|
||||
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
|
||||
type AccountSeqEntity string
|
||||
|
||||
const (
|
||||
AccountSeqEntityPolicy AccountSeqEntity = "policy"
|
||||
AccountSeqEntityGroup AccountSeqEntity = "group"
|
||||
AccountSeqEntityRoute AccountSeqEntity = "route"
|
||||
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
|
||||
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
|
||||
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
|
||||
AccountSeqEntityNetwork AccountSeqEntity = "network"
|
||||
AccountSeqEntityPostureCheck AccountSeqEntity = "posture_check"
|
||||
)
|
||||
|
||||
// AccountSeqCounter tracks the next per-account integer id for a given component
|
||||
// kind. Reads/writes go through the store inside the same transaction as the
|
||||
// component insert so two concurrent inserts cannot collide on the same id.
|
||||
type AccountSeqCounter struct {
|
||||
AccountID string `gorm:"primaryKey;size:255"`
|
||||
Entity string `gorm:"primaryKey;size:32"`
|
||||
NextID uint32 `gorm:"not null;default:1"`
|
||||
}
|
||||
|
||||
// TableName overrides the GORM-derived table name.
|
||||
func (AccountSeqCounter) TableName() string {
|
||||
return "account_seq_counters"
|
||||
}
|
||||
@@ -19,6 +19,10 @@ type Group struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
@@ -41,6 +45,14 @@ type GroupPeer struct {
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the group has been persisted long enough to have a
|
||||
// per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip groups that return false here — otherwise multiple unpersisted
|
||||
// groups would collide on id 0.
|
||||
func (g *Group) HasSeqID() bool {
|
||||
return g != nil && g.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
@@ -74,6 +86,7 @@ func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
AccountSeqID: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
|
||||
@@ -42,6 +42,17 @@ type NetworkMapComponents struct {
|
||||
PostureFailedPeers map[string]map[string]struct{}
|
||||
|
||||
RouterPeers map[string]*nbpeer.Peer
|
||||
|
||||
// NetworkXIDToSeq maps Network.ID (xid) → AccountSeqID. Populated by the
|
||||
// account-side component builder; consumed by the envelope encoder to
|
||||
// translate RoutersMap keys and NetworkResource.NetworkID references
|
||||
// to compact uint32 ids. Legacy Calculate() doesn't consult it.
|
||||
NetworkXIDToSeq map[string]uint32
|
||||
|
||||
// PostureCheckXIDToSeq maps posture.Checks.ID (xid) → AccountSeqID.
|
||||
// Same role as NetworkXIDToSeq, used for PostureFailedPeers keys and
|
||||
// policy SourcePostureChecks references.
|
||||
PostureCheckXIDToSeq map[string]uint32
|
||||
}
|
||||
|
||||
type AccountSettingsInfo struct {
|
||||
|
||||
181
management/server/types/networkmap_wire_benchmark_test.go
Normal file
181
management/server/types/networkmap_wire_benchmark_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// wireBenchScales mirrors the scales used by networkmap_benchmark_test.go but
|
||||
// trimmed: encoding+marshal are linear, so we don't need the 30k peer extreme
|
||||
// to see the trend.
|
||||
var wireBenchScales = []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
}
|
||||
|
||||
// populateAccountSeqIDs assigns deterministic AccountSeqIDs to every group and
|
||||
// policy in the account so that the component encoder can reference them. The
|
||||
// scalableTestAccount fixture builds entities by struct literal and skips this
|
||||
// step, but production paths populate the IDs via the store layer.
|
||||
func populateAccountSeqIDs(account *types.Account) {
|
||||
var nextGroupSeq uint32 = 1
|
||||
for _, g := range account.Groups {
|
||||
g.AccountSeqID = nextGroupSeq
|
||||
nextGroupSeq++
|
||||
}
|
||||
var nextPolicySeq uint32 = 1
|
||||
for _, p := range account.Policies {
|
||||
p.AccountSeqID = nextPolicySeq
|
||||
nextPolicySeq++
|
||||
}
|
||||
}
|
||||
|
||||
// assignValidWgKeys overwrites every peer's Key with a valid base64-encoded
|
||||
// 32-byte string. The default scalableTestAccount uses unparsable strings
|
||||
// like "key-peer-0", which makes the components encoder emit a nil WgPubKey
|
||||
// and the legacy encoder ship 10-char placeholders — both shrink the wire
|
||||
// size in unrealistic ways. Production peers always have valid 44-char base64
|
||||
// keys, so any benchmark/breakdown that wants honest numbers must call this.
|
||||
func assignValidWgKeys(account *types.Account) {
|
||||
for _, p := range account.Peers {
|
||||
var raw [32]byte
|
||||
_, _ = rand.Read(raw[:])
|
||||
p.Key = base64.StdEncoding.EncodeToString(raw[:])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireEncode reports per-call ns and the marshaled wire
|
||||
// size for both encoding paths. Run with:
|
||||
//
|
||||
// go test -run=^$ -bench=BenchmarkNetworkMapWireEncode -benchmem ./management/server/types/
|
||||
func BenchmarkNetworkMapWireEncode(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
// Pre-encode once so the size metric is identical for every run inside
|
||||
// the same scale; the b.Loop call only re-runs encode + Marshal.
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
envelopeInput := mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
}
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
envelopeBytes, err := goproto.Marshal(envelope)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("legacy/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(legacyBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
resp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
if _, err := goproto.Marshal(resp.NetworkMap); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("components/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(envelopeBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
if _, err := goproto.Marshal(env); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireSize is a fast snapshot of the wire size by scale
|
||||
// without a tight encode loop. Run with -bench to see one ns/op + bytes per
|
||||
// scale (treat the timing as informational; the sample is one Marshal per
|
||||
// scale, not the full b.N loop).
|
||||
func BenchmarkNetworkMapWireSize(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
envBytes, err := goproto.Marshal(env)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("size/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportMetric(float64(len(legacyBytes)), "legacy_bytes")
|
||||
b.ReportMetric(float64(len(envBytes)), "components_bytes")
|
||||
ratio := float64(len(envBytes)) / float64(len(legacyBytes))
|
||||
b.ReportMetric(ratio, "components/legacy")
|
||||
for range b.N {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
150
management/server/types/networkmap_wire_breakdown_test.go
Normal file
150
management/server/types/networkmap_wire_breakdown_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkMapWireBreakdown is a one-shot diagnostic: it computes the wire
|
||||
// size attributable to each top-level field of both the legacy NetworkMap and
|
||||
// the components NetworkMapEnvelope at the 5000-peer scale, so the migration
|
||||
// docs can attribute the size reduction to each optimization. Runs only on
|
||||
// demand via -run TestNetworkMapWireBreakdown.
|
||||
func TestNetworkMapWireBreakdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("size diagnostic, skipped with -short")
|
||||
}
|
||||
if os.Getenv("NB_RUN_WIRE_BREAKDOWN") != "1" {
|
||||
t.Skip("set NB_RUN_WIRE_BREAKDOWN=1 to run wire breakdown diagnostic")
|
||||
}
|
||||
|
||||
const peerCount, groupCount = 5000, 100
|
||||
account, validatedPeers := scalableTestAccount(peerCount, groupCount)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyTotal := mustMarshalSize(t, legacyResp.NetworkMap)
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
componentsTotal := mustMarshalSize(t, envelope)
|
||||
|
||||
t.Logf("\n=== LEGACY NetworkMap (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes\n", legacyTotal)
|
||||
|
||||
legacyBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMap
|
||||
}{
|
||||
{"RemotePeers", &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers}},
|
||||
{"OfflinePeers", &proto.NetworkMap{OfflinePeers: legacyResp.NetworkMap.OfflinePeers}},
|
||||
{"FirewallRules", &proto.NetworkMap{FirewallRules: legacyResp.NetworkMap.FirewallRules}},
|
||||
{"Routes", &proto.NetworkMap{Routes: legacyResp.NetworkMap.Routes}},
|
||||
{"RoutesFirewallRules", &proto.NetworkMap{RoutesFirewallRules: legacyResp.NetworkMap.RoutesFirewallRules}},
|
||||
{"DNSConfig", &proto.NetworkMap{DNSConfig: legacyResp.NetworkMap.DNSConfig}},
|
||||
{"PeerConfig", &proto.NetworkMap{PeerConfig: legacyResp.NetworkMap.PeerConfig}},
|
||||
{"SshAuth", &proto.NetworkMap{SshAuth: legacyResp.NetworkMap.SshAuth}},
|
||||
}
|
||||
for _, e := range legacyBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, legacyTotal))
|
||||
}
|
||||
|
||||
full := envelope.GetFull()
|
||||
if full == nil {
|
||||
t.Fatalf("expected full network map envelope payload, got nil")
|
||||
}
|
||||
t.Logf("\n=== COMPONENTS NetworkMapEnvelope (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes (%.1f%% of legacy)\n", componentsTotal, pct(componentsTotal, legacyTotal))
|
||||
|
||||
componentsBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMapComponentsFull
|
||||
}{
|
||||
{"Peers", &proto.NetworkMapComponentsFull{Peers: full.Peers}},
|
||||
{"Policies", &proto.NetworkMapComponentsFull{Policies: full.Policies}},
|
||||
{"Groups", &proto.NetworkMapComponentsFull{Groups: full.Groups}},
|
||||
{"Routes (raw)", &proto.NetworkMapComponentsFull{Routes: full.Routes}},
|
||||
{"NameServerGroups", &proto.NetworkMapComponentsFull{NameserverGroups: full.NameserverGroups}},
|
||||
{"AllDNSRecords", &proto.NetworkMapComponentsFull{AllDnsRecords: full.AllDnsRecords}},
|
||||
{"AccountZones", &proto.NetworkMapComponentsFull{AccountZones: full.AccountZones}},
|
||||
{"NetworkResources", &proto.NetworkMapComponentsFull{NetworkResources: full.NetworkResources}},
|
||||
{"RoutersMap", &proto.NetworkMapComponentsFull{RoutersMap: full.RoutersMap}},
|
||||
{"ResourcePoliciesMap", &proto.NetworkMapComponentsFull{ResourcePoliciesMap: full.ResourcePoliciesMap}},
|
||||
{"GroupIDToUserIDs", &proto.NetworkMapComponentsFull{GroupIdToUserIds: full.GroupIdToUserIds}},
|
||||
{"AllowedUserIDs", &proto.NetworkMapComponentsFull{AllowedUserIds: full.AllowedUserIds}},
|
||||
{"PostureFailedPeers", &proto.NetworkMapComponentsFull{PostureFailedPeers: full.PostureFailedPeers}},
|
||||
{"DNSSettings", &proto.NetworkMapComponentsFull{DnsSettings: full.DnsSettings}},
|
||||
{"PeerConfig", &proto.NetworkMapComponentsFull{PeerConfig: full.PeerConfig}},
|
||||
{"AgentVersions", &proto.NetworkMapComponentsFull{AgentVersions: full.AgentVersions}},
|
||||
}
|
||||
for _, e := range componentsBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, componentsTotal))
|
||||
}
|
||||
|
||||
t.Logf("\n=== Per-PeerCompact average ===")
|
||||
if len(full.Peers) > 0 {
|
||||
t.Logf(" PeerCompact avg: %d bytes/peer", mustMarshalSize(t, &proto.NetworkMapComponentsFull{Peers: full.Peers})/len(full.Peers))
|
||||
}
|
||||
if len(legacyResp.NetworkMap.RemotePeers) > 0 {
|
||||
t.Logf(" RemotePeer avg: %d bytes/peer",
|
||||
mustMarshalSize(t, &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers})/len(legacyResp.NetworkMap.RemotePeers))
|
||||
}
|
||||
|
||||
t.Logf("\n=== FirewallRule expansion footprint ===")
|
||||
t.Logf(" legacy FirewallRules count: %d", len(legacyResp.NetworkMap.FirewallRules))
|
||||
t.Logf(" components Policies count: %d", len(full.Policies))
|
||||
t.Logf(" components Groups count: %d", len(full.Groups))
|
||||
|
||||
totalGroupPeerIdxs := 0
|
||||
for _, g := range full.Groups {
|
||||
totalGroupPeerIdxs += len(g.PeerIndexes)
|
||||
}
|
||||
t.Logf(" components peer-index refs across all groups: %d", totalGroupPeerIdxs)
|
||||
}
|
||||
|
||||
func mustMarshalSize(t *testing.T, m goproto.Message) int {
|
||||
b, err := goproto.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
return len(b)
|
||||
}
|
||||
|
||||
func pct(part, total int) float64 {
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
return 100 * float64(part) / float64(total)
|
||||
}
|
||||
|
||||
// Stops fmt being unused if the breakdown loop above is later commented out.
|
||||
var _ = fmt.Sprintf
|
||||
25
management/server/types/peer_networkmap_result.go
Normal file
25
management/server/types/peer_networkmap_result.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
// PeerNetworkMapResult is what the network_map controller produces for a
|
||||
// single peer. Exactly one of NetworkMap or Components is populated depending
|
||||
// on the peer's capability:
|
||||
//
|
||||
// - Components-capable peers (PeerCapabilityComponentNetworkMap) get
|
||||
// Components: the raw types.NetworkMapComponents the client decodes and
|
||||
// runs Calculate() on locally. NetworkMap stays nil — the server skips
|
||||
// the expansion entirely.
|
||||
// - Legacy peers (or any peer when the kill switch is set) get NetworkMap:
|
||||
// the fully-expanded view the legacy gRPC path consumes.
|
||||
//
|
||||
// The gRPC layer (ToSyncResponseForPeer) dispatches by which field is
|
||||
// non-nil; callers must not rely on both being set.
|
||||
type PeerNetworkMapResult struct {
|
||||
NetworkMap *NetworkMap
|
||||
Components *NetworkMapComponents
|
||||
}
|
||||
|
||||
// IsComponents reports whether the result carries the components shape.
|
||||
// Use this in preference to direct nil checks on the fields.
|
||||
func (r PeerNetworkMapResult) IsComponents() bool {
|
||||
return r.Components != nil
|
||||
}
|
||||
104
management/server/types/peer_networkmap_result_test.go
Normal file
104
management/server/types/peer_networkmap_result_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// helper: marks the given peer as components-capable.
|
||||
func markCapable(p *nbpeer.Peer) {
|
||||
p.Meta.Capabilities = append(p.Meta.Capabilities, nbpeer.PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_CapablePeerGetsComponents(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false, // componentsDisabled
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
require.True(t, result.IsComponents(), "capable peer must get the components shape")
|
||||
assert.Nil(t, result.NetworkMap)
|
||||
require.NotNil(t, result.Components)
|
||||
assert.Equal(t, "peer-0", result.Components.PeerID)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_LegacyPeerGetsNetworkMap(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
// peer-0 left without the component capability
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false,
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents())
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap, "legacy peer must get a NetworkMap")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_KillSwitchOverridesCapability(t *testing.T) {
|
||||
// Capable peer + componentsDisabled=true → falls back to legacy.
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
true, // componentsDisabled = true (kill switch)
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents(), "kill switch must force legacy NetworkMap path")
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap)
|
||||
}
|
||||
|
||||
func TestPeerNetworkMapResult_IsComponents(t *testing.T) {
|
||||
assert.True(t, types.PeerNetworkMapResult{Components: &types.NetworkMapComponents{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{NetworkMap: &types.NetworkMap{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{}.IsComponents())
|
||||
}
|
||||
@@ -59,6 +59,10 @@ type Policy struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
@@ -75,11 +79,19 @@ type Policy struct {
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the policy has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip policies that return false here.
|
||||
func (p *Policy) HasSeqID() bool {
|
||||
return p != nil && p.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// Copy returns a copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
AccountSeqID: p.AccountSeqID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
|
||||
@@ -95,6 +95,9 @@ type Route struct {
|
||||
ID ID `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"`
|
||||
// Network and Domains are mutually exclusive
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
@@ -128,6 +131,7 @@ func (r *Route) Copy() *Route {
|
||||
route := &Route{
|
||||
ID: r.ID,
|
||||
AccountID: r.AccountID,
|
||||
AccountSeqID: r.AccountSeqID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
|
||||
@@ -316,27 +316,74 @@ func TestClient_Sync(t *testing.T) {
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
if resp.GetPeerConfig() == nil {
|
||||
if resp.GetPeerConfig() == nil && resp.GetNetworkMap().GetPeerConfig() == nil {
|
||||
t.Error("expecting non nil PeerConfig got nil")
|
||||
}
|
||||
if resp.GetNetbirdConfig() == nil {
|
||||
t.Error("expecting non nil NetbirdConfig got nil")
|
||||
}
|
||||
if len(resp.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||||
// Component-capable clients receive a NetworkMapEnvelope; the
|
||||
// remote-peers list is encoded inside it. Decode it and check the
|
||||
// envelope's peers slice. Legacy peers populate the top-level
|
||||
// RemotePeers; both shapes must surface exactly one remote peer.
|
||||
remotePeerKeys := remotePeerKeysFromSync(resp, testKey.PublicKey().String())
|
||||
if len(remotePeerKeys) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(remotePeerKeys))
|
||||
return
|
||||
}
|
||||
if resp.GetRemotePeersIsEmpty() == true {
|
||||
if resp.GetNetworkMap() != nil && resp.GetRemotePeersIsEmpty() {
|
||||
t.Error("expecting RemotePeers property to be false, got true")
|
||||
}
|
||||
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
|
||||
if remotePeerKeys[0] != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), remotePeerKeys[0])
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for test to finish")
|
||||
}
|
||||
}
|
||||
|
||||
// remotePeerKeysFromSync extracts the remote-peer WG keys from either the
|
||||
// legacy NetworkMap.RemotePeers list or the components NetworkMapEnvelope's
|
||||
// inner peers slice (filtering out the local receiving peer identified by
|
||||
// localKey, since the envelope's peers list is index-addressed and includes
|
||||
// the local peer alongside remotes).
|
||||
func remotePeerKeysFromSync(resp *mgmtProto.SyncResponse, localKey string) []string {
|
||||
if rp := resp.GetRemotePeers(); len(rp) > 0 {
|
||||
out := make([]string, 0, len(rp))
|
||||
for _, p := range rp {
|
||||
out = append(out, p.GetWgPubKey())
|
||||
}
|
||||
return out
|
||||
}
|
||||
env := resp.GetNetworkMapEnvelope().GetFull()
|
||||
if env == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(env.GetPeers()))
|
||||
for _, p := range env.GetPeers() {
|
||||
key := wgKeyFromBytes(p.GetWgPubKey())
|
||||
if key == "" || key == localKey {
|
||||
continue
|
||||
}
|
||||
out = append(out, key)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// wgKeyFromBytes mirrors the client-side decoder: the envelope ships raw 32
|
||||
// bytes; reconstruct the standard base64 key the test compares against.
|
||||
func wgKeyFromBytes(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var k wgtypes.Key
|
||||
if len(raw) != len(k) {
|
||||
return ""
|
||||
}
|
||||
copy(k[:], raw)
|
||||
return k.String()
|
||||
}
|
||||
|
||||
func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
|
||||
defer s.GracefulStop()
|
||||
|
||||
@@ -950,6 +950,13 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
func peerCapabilities(info system.Info) []proto.PeerCapability {
|
||||
caps := []proto.PeerCapability{
|
||||
proto.PeerCapability_PeerCapabilitySourcePrefixes,
|
||||
// PeerCapabilityComponentNetworkMap signals that this client can
|
||||
// decode the components-format SyncResponse.NetworkMapEnvelope and
|
||||
// run Calculate() locally. Always advertised by Step-4-capable
|
||||
// builds — there's no opt-out flag because the server-side kill
|
||||
// switch (NB_NETWORK_MAP_COMPONENTS_DISABLE) covers emergency
|
||||
// rollback and the client decoder is built in.
|
||||
proto.PeerCapability_PeerCapabilityComponentNetworkMap,
|
||||
}
|
||||
if !info.DisableIPv6 {
|
||||
caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay)
|
||||
|
||||
607
shared/management/networkmap/decode.go
Normal file
607
shared/management/networkmap/decode.go
Normal file
@@ -0,0 +1,607 @@
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents
|
||||
// the client can run Calculate() over. Every ID-reference on the wire is a
|
||||
// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder
|
||||
// synthesises consistent string IDs from the uint32s so the reconstructed
|
||||
// components struct round-trips through Calculate exactly the way the
|
||||
// server-side typed components would.
|
||||
//
|
||||
// Synthetic ID scheme (underscore-separated, visually distinct from the xid
|
||||
// format Calculate would put in log lines under the legacy path):
|
||||
//
|
||||
// Peers "p_<wire_index>" // envelope.peers is index-addressed
|
||||
// Groups "g_<account_seq_id>"
|
||||
// Policies "pol_<account_seq_id>" // 1 rule per policy
|
||||
// Routes "r_<account_seq_id>"
|
||||
// Network resources "nres_<account_seq_id>"
|
||||
// Posture checks "pc_<account_seq_id>"
|
||||
// Networks "net_<account_seq_id>"
|
||||
// Nameserver groups "nsg_<account_seq_id>"
|
||||
func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) {
|
||||
if env == nil {
|
||||
return nil, fmt.Errorf("nil envelope")
|
||||
}
|
||||
full := env.GetFull()
|
||||
if full == nil {
|
||||
return nil, fmt.Errorf("envelope has no Full payload")
|
||||
}
|
||||
|
||||
c := &types.NetworkMapComponents{
|
||||
PeerID: "", // engine fills its own peer id from PeerConfig
|
||||
Network: decodeAccountNetwork(full.Network),
|
||||
AccountSettings: decodeAccountSettings(full.AccountSettings),
|
||||
CustomZoneDomain: full.CustomZoneDomain,
|
||||
Peers: make(map[string]*nbpeer.Peer, len(full.Peers)),
|
||||
Groups: make(map[string]*types.Group, len(full.Groups)),
|
||||
Policies: make([]*types.Policy, 0, len(full.Policies)),
|
||||
Routes: make([]*nbroute.Route, 0, len(full.Routes)),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)),
|
||||
AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords),
|
||||
AccountZones: decodeCustomZones(full.AccountZones),
|
||||
ResourcePoliciesMap: make(map[string][]*types.Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
AllowedUserIDs: stringSliceToSet(full.AllowedUserIds),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)),
|
||||
GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)),
|
||||
}
|
||||
|
||||
if full.DnsSettings != nil {
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds),
|
||||
}
|
||||
} else {
|
||||
c.DNSSettings = &types.DNSSettings{}
|
||||
}
|
||||
|
||||
// Phase 1: peers. The envelope's peers slice is index-addressed; we
|
||||
// build a peerOrder lookup for downstream references. Peer.ID is
|
||||
// synthesized from the peer's wire index — wire format ships no xid
|
||||
// for peers (and never has).
|
||||
peerIDByIndex := make([]string, len(full.Peers))
|
||||
for idx, pc := range full.Peers {
|
||||
if pc == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: peers[%d] is nil", idx)
|
||||
}
|
||||
peerID := synthPeerID(uint32(idx))
|
||||
peer := decodePeerCompact(pc, peerID, full.AgentVersions)
|
||||
c.Peers[peerID] = peer
|
||||
peerIDByIndex[idx] = peerID
|
||||
}
|
||||
|
||||
// Phase 2: groups. AccountSeqID becomes both the synthesized string ID
|
||||
// and the GroupCompact.id wire value.
|
||||
for i, gc := range full.Groups {
|
||||
if gc == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: groups[%d] is nil", i)
|
||||
}
|
||||
groupID := synthGroupID(gc.Id)
|
||||
peerIDs := make([]string, 0, len(gc.PeerIndexes))
|
||||
for _, idx := range gc.PeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
peerIDs = append(peerIDs, peerIDByIndex[idx])
|
||||
}
|
||||
}
|
||||
c.Groups[groupID] = &types.Group{
|
||||
ID: groupID,
|
||||
AccountSeqID: gc.Id,
|
||||
Name: gc.Name,
|
||||
Peers: peerIDs,
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: policies (PolicyCompact = one rule per entry; current data
|
||||
// model is 1 rule per policy). Policy.ID is synthesized from the
|
||||
// per-account seq id; proto.FirewallRule.PolicyID downstream carries
|
||||
// the same synth string (no xid on the wire).
|
||||
for i, pc := range full.Policies {
|
||||
if pc == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: policies[%d] is nil", i)
|
||||
}
|
||||
policyID := synthPolicyID(pc.Id)
|
||||
c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex))
|
||||
}
|
||||
|
||||
// Phase 4: routes.
|
||||
for i, rr := range full.Routes {
|
||||
if rr == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: routes[%d] is nil", i)
|
||||
}
|
||||
c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex))
|
||||
}
|
||||
|
||||
// Phase 5: NSGs.
|
||||
for i, nsg := range full.NameserverGroups {
|
||||
if nsg == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: nameserver_groups[%d] is nil", i)
|
||||
}
|
||||
c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg))
|
||||
}
|
||||
|
||||
// Phase 6: network resources.
|
||||
for i, nr := range full.NetworkResources {
|
||||
if nr == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: network_resources[%d] is nil", i)
|
||||
}
|
||||
c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr))
|
||||
}
|
||||
|
||||
// Phase 7: routers_map (outer key = network seq id, inner key = peer-id
|
||||
// reconstructed from peer_index). Synthesized network id is "net_<seq>".
|
||||
for networkSeq, list := range full.RoutersMap {
|
||||
networkID := synthNetworkID(networkSeq)
|
||||
inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries))
|
||||
for _, entry := range list.Entries {
|
||||
if !entry.PeerIndexSet {
|
||||
continue
|
||||
}
|
||||
if int(entry.PeerIndex) >= len(peerIDByIndex) {
|
||||
continue
|
||||
}
|
||||
peerID := peerIDByIndex[entry.PeerIndex]
|
||||
inner[peerID] = &routerTypes.NetworkRouter{
|
||||
ID: "",
|
||||
NetworkID: networkID,
|
||||
AccountSeqID: entry.Id,
|
||||
Peer: peerID,
|
||||
PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds),
|
||||
Masquerade: entry.Masquerade,
|
||||
Metric: int(entry.Metric),
|
||||
Enabled: entry.Enabled,
|
||||
}
|
||||
}
|
||||
if len(inner) > 0 {
|
||||
c.RoutersMap[networkID] = inner
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 8: resource_policies_map (resource seq id → list of *types.Policy
|
||||
// pointers from the decoded policies slice). Resource ID is synthesized
|
||||
// the same way as in decodeNetworkResource.
|
||||
for resourceSeq, idxs := range full.ResourcePoliciesMap {
|
||||
if len(idxs.Indexes) == 0 {
|
||||
continue
|
||||
}
|
||||
resourceID := synthNetworkResourceID(resourceSeq)
|
||||
policies := make([]*types.Policy, 0, len(idxs.Indexes))
|
||||
for _, i := range idxs.Indexes {
|
||||
if int(i) < len(c.Policies) {
|
||||
policies = append(policies, c.Policies[i])
|
||||
}
|
||||
}
|
||||
if len(policies) > 0 {
|
||||
c.ResourcePoliciesMap[resourceID] = policies
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings.
|
||||
for groupSeq, list := range full.GroupIdToUserIds {
|
||||
c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...)
|
||||
}
|
||||
|
||||
// Phase 10: posture_failed_peers — wire keys are posture-check seq ids,
|
||||
// values are peer indexes that need to be turned into peer ids. PolicyRule
|
||||
// SourcePostureChecks (also synth ids) reference the same key space.
|
||||
for checkSeq, set := range full.PostureFailedPeers {
|
||||
checkID := synthPostureCheckID(checkSeq)
|
||||
failed := make(map[string]struct{}, len(set.PeerIndexes))
|
||||
for _, idx := range set.PeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
failed[peerIDByIndex[idx]] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.PostureFailedPeers[checkID] = failed
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 11: router_peer_indexes — peers that act as routers. They're
|
||||
// already in c.Peers (router peers are appended to the global peers
|
||||
// list by the encoder); RouterPeers is the subset.
|
||||
for _, idx := range full.RouterPeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
peerID := peerIDByIndex[idx]
|
||||
c.RouterPeers[peerID] = c.Peers[peerID]
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network {
|
||||
if an == nil {
|
||||
return nil
|
||||
}
|
||||
n := &types.Network{
|
||||
Identifier: an.Identifier,
|
||||
Dns: an.Dns,
|
||||
Serial: an.Serial,
|
||||
}
|
||||
if an.NetCidr != "" {
|
||||
if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil {
|
||||
n.Net = *ipnet
|
||||
}
|
||||
}
|
||||
if an.NetV6Cidr != "" {
|
||||
if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil {
|
||||
n.NetV6 = *ipnet
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo {
|
||||
if as == nil {
|
||||
return &types.AccountSettingsInfo{}
|
||||
}
|
||||
return &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs),
|
||||
}
|
||||
}
|
||||
|
||||
func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer {
|
||||
var caps []int32
|
||||
if pc.SupportsSourcePrefixes {
|
||||
caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
if pc.SupportsIpv6 {
|
||||
caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
Key: encodeWgKeyBase64(pc.WgPubKey),
|
||||
SSHKey: string(pc.SshPubKey),
|
||||
SSHEnabled: pc.SshEnabled,
|
||||
DNSLabel: pc.DnsLabel,
|
||||
LoginExpirationEnabled: pc.LoginExpirationEnabled,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx),
|
||||
Capabilities: caps,
|
||||
Flags: nbpeer.Flags{
|
||||
ServerSSHAllowed: pc.ServerSshAllowed,
|
||||
},
|
||||
},
|
||||
}
|
||||
if pc.AddedWithSsoLogin {
|
||||
// Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true.
|
||||
// The original UserID isn't on the wire; the value is intentionally
|
||||
// visibly synthetic so any future consumer that mistakes UserID for a
|
||||
// real account user xid won't silently match (or worse, write the
|
||||
// sentinel into a downstream record).
|
||||
peer.UserID = "<env-sso>"
|
||||
}
|
||||
if pc.LastLoginUnixNano != 0 {
|
||||
t := time.Unix(0, pc.LastLoginUnixNano)
|
||||
peer.LastLogin = &t
|
||||
}
|
||||
switch len(pc.Ip) {
|
||||
case 4:
|
||||
peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]})
|
||||
case 16:
|
||||
var a [16]byte
|
||||
copy(a[:], pc.Ip)
|
||||
peer.IP = netip.AddrFrom16(a)
|
||||
}
|
||||
if len(pc.Ipv6) == 16 {
|
||||
var a [16]byte
|
||||
copy(a[:], pc.Ipv6)
|
||||
peer.IPv6 = netip.AddrFrom16(a)
|
||||
}
|
||||
return peer
|
||||
}
|
||||
|
||||
func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy {
|
||||
rule := &types.PolicyRule{
|
||||
ID: policyID, // 1 rule per policy → reuse synthesized id
|
||||
PolicyID: policyID,
|
||||
Enabled: true,
|
||||
Action: actionFromProto(pc.Action),
|
||||
Protocol: protocolFromProto(pc.Protocol),
|
||||
Bidirectional: pc.Bidirectional,
|
||||
Ports: uint32SliceToStrings(pc.Ports),
|
||||
PortRanges: portRangesFromProto(pc.PortRanges),
|
||||
Sources: groupIDsFromSeqs(pc.SourceGroupIds),
|
||||
Destinations: groupIDsFromSeqs(pc.DestinationGroupIds),
|
||||
AuthorizedUser: pc.AuthorizedUser,
|
||||
AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups),
|
||||
SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex),
|
||||
DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex),
|
||||
}
|
||||
return &types.Policy{
|
||||
ID: policyID,
|
||||
AccountSeqID: pc.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds),
|
||||
}
|
||||
}
|
||||
|
||||
// resourceFromProto rebuilds types.Resource. For peer-typed resources the
|
||||
// peer reference is reconstructed from the envelope's peer index — wire
|
||||
// format ships no xid for peers, so we use the synthesized peer id.
|
||||
func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource {
|
||||
if r == nil {
|
||||
return types.Resource{}
|
||||
}
|
||||
out := types.Resource{Type: types.ResourceType(r.Type)}
|
||||
if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) {
|
||||
out.ID = peerIDByIndex[r.PeerIndex]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids.
|
||||
// Mirrors groupIDsFromSeqs.
|
||||
func postureCheckIDsFromSeqs(seqs []uint32) []string {
|
||||
if len(seqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(seqs))
|
||||
for i, s := range seqs {
|
||||
out[i] = synthPostureCheckID(s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form
|
||||
// keys by group account_seq_id, the typed PolicyRule field keys by group
|
||||
// xid string. We rebuild using the same synthetic scheme the rest of the
|
||||
// decoder uses ("g<seq>").
|
||||
func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string][]string, len(m))
|
||||
for seq, list := range m {
|
||||
if list == nil {
|
||||
continue
|
||||
}
|
||||
out[synthGroupID(seq)] = append([]string(nil), list.Names...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route {
|
||||
r := &nbroute.Route{
|
||||
ID: nbroute.ID(synthRouteID(rr.Id)),
|
||||
AccountSeqID: rr.Id,
|
||||
NetID: nbroute.NetID(rr.NetId),
|
||||
Description: rr.Description,
|
||||
Domains: domainsFromPunycode(rr.Domains),
|
||||
KeepRoute: rr.KeepRoute,
|
||||
NetworkType: nbroute.NetworkType(rr.NetworkType),
|
||||
Masquerade: rr.Masquerade,
|
||||
Metric: int(rr.Metric),
|
||||
Enabled: rr.Enabled,
|
||||
Groups: groupIDsFromSeqs(rr.GroupIds),
|
||||
AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds),
|
||||
PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds),
|
||||
SkipAutoApply: rr.SkipAutoApply,
|
||||
}
|
||||
if rr.NetworkCidr != "" {
|
||||
if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil {
|
||||
r.Network = p
|
||||
}
|
||||
}
|
||||
if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) {
|
||||
r.Peer = peerIDByIndex[rr.PeerIndex]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup {
|
||||
out := &nbdns.NameServerGroup{
|
||||
ID: synthNameServerGroupID(nsg.Id),
|
||||
AccountSeqID: nsg.Id,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Groups: groupIDsFromSeqs(nsg.GroupIds),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)),
|
||||
}
|
||||
for _, ns := range nsg.Nameservers {
|
||||
if addr, err := netip.ParseAddr(ns.IP); err == nil {
|
||||
out.NameServers = append(out.NameServers, nbdns.NameServer{
|
||||
IP: addr,
|
||||
NSType: nbdns.NameServerType(ns.NSType),
|
||||
Port: int(ns.Port),
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource {
|
||||
out := &resourceTypes.NetworkResource{
|
||||
ID: synthNetworkResourceID(nr.Id),
|
||||
AccountSeqID: nr.Id,
|
||||
NetworkID: synthNetworkID(nr.NetworkSeq),
|
||||
Name: nr.Name,
|
||||
Description: nr.Description,
|
||||
Type: resourceTypes.NetworkResourceType(nr.Type),
|
||||
Address: nr.Address,
|
||||
Domain: nr.DomainValue,
|
||||
Enabled: nr.Enabled,
|
||||
}
|
||||
if nr.PrefixCidr != "" {
|
||||
if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil {
|
||||
out.Prefix = p
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord {
|
||||
out := make([]nbdns.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, nbdns.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone {
|
||||
out := make([]nbdns.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, nbdns.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: decodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Synthetic ID generators — deterministic given the same wire input.
|
||||
// Underscore-separated ("p_<n>", "pol_<n>", ...) so they're visually
|
||||
// distinct in operator logs. fmt.Sprintf would dominate the decode hot path
|
||||
// on large accounts (a 10k-peer envelope produces ~50k synth calls); the
|
||||
// strconv.AppendUint builder keeps it allocation-light.
|
||||
func synthID(prefix string, n uint32) string {
|
||||
buf := make([]byte, 0, len(prefix)+10)
|
||||
buf = append(buf, prefix...)
|
||||
buf = strconv.AppendUint(buf, uint64(n), 10)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func synthPeerID(idx uint32) string { return synthID("p_", idx) }
|
||||
func synthGroupID(seq uint32) string { return synthID("g_", seq) }
|
||||
func synthPolicyID(seq uint32) string { return synthID("pol_", seq) }
|
||||
func synthRouteID(seq uint32) string { return synthID("r_", seq) }
|
||||
func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) }
|
||||
func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) }
|
||||
func synthNetworkID(seq uint32) string { return synthID("net_", seq) }
|
||||
func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) }
|
||||
|
||||
func groupIDsFromSeqs(seqs []uint32) []string {
|
||||
if len(seqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(seqs))
|
||||
for i, s := range seqs {
|
||||
out[i] = synthGroupID(s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func uint32SliceToStrings(ports []uint32) []string {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(ports))
|
||||
for i, p := range ports {
|
||||
out[i] = strconv.FormatUint(uint64(p), 10)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]types.RulePortRange, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
if r == nil || r.Start > 65535 || r.End > 65535 {
|
||||
continue
|
||||
}
|
||||
out = append(out, types.RulePortRange{
|
||||
Start: uint16(r.Start),
|
||||
End: uint16(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType {
|
||||
if a == proto.RuleAction_DROP {
|
||||
return types.PolicyTrafficActionDrop
|
||||
}
|
||||
return types.PolicyTrafficActionAccept
|
||||
}
|
||||
|
||||
func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType {
|
||||
switch p {
|
||||
case proto.RuleProtocol_TCP:
|
||||
return types.PolicyRuleProtocolTCP
|
||||
case proto.RuleProtocol_UDP:
|
||||
return types.PolicyRuleProtocolUDP
|
||||
case proto.RuleProtocol_ICMP:
|
||||
return types.PolicyRuleProtocolICMP
|
||||
case proto.RuleProtocol_ALL:
|
||||
return types.PolicyRuleProtocolALL
|
||||
case proto.RuleProtocol_NETBIRD_SSH:
|
||||
return types.PolicyRuleProtocolNetbirdSSH
|
||||
default:
|
||||
return types.PolicyRuleProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func encodeWgKeyBase64(raw []byte) string {
|
||||
if len(raw) != 32 {
|
||||
return ""
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(raw)
|
||||
}
|
||||
|
||||
func lookupAgentVersion(table []string, idx uint32) string {
|
||||
if int(idx) < len(table) {
|
||||
return table[idx]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringSliceToSet(s []string) map[string]struct{} {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(s))
|
||||
for _, v := range s {
|
||||
out[v] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// domainsFromPunycode is a thin wrapper that converts a punycode list back to
|
||||
// the domain.List type the route.Route struct expects. It accepts the
|
||||
// punycode strings as-is (no extra decoding) — symmetric with
|
||||
// route.Domains.ToPunycodeList() used in the encoder.
|
||||
func domainsFromPunycode(punycoded []string) domain.List {
|
||||
if len(punycoded) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(domain.List, 0, len(punycoded))
|
||||
for _, d := range punycoded {
|
||||
out = append(out, domain.Domain(d))
|
||||
}
|
||||
return out
|
||||
}
|
||||
323
shared/management/networkmap/encode.go
Normal file
323
shared/management/networkmap/encode.go
Normal file
@@ -0,0 +1,323 @@
|
||||
// Package networkmap contains the shared NetworkMap helpers that both the
|
||||
// management server and the client agent need.
|
||||
//
|
||||
// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live
|
||||
// here so the client can run the same conversion locally after deriving its
|
||||
// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the
|
||||
// server-side conversion package (which pulls in cloud integrations and is
|
||||
// otherwise an unwanted internal import on the client).
|
||||
//
|
||||
// The helpers are pure functions over inputs — no caches, no IO, no logging
|
||||
// beyond a context-aware error log when an individual user-id hash fails.
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"net/netip"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
// ToProtocolRoutes converts a slice of typed routes to their proto form.
|
||||
func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, ToProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
// ToProtocolRoute converts one typed route to its proto form.
|
||||
func ToProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// ToProtocolFirewallRules converts the firewall rules to the protocol form.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is
|
||||
// populated alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: GetProtoDirection(rule.Direction),
|
||||
Action: GetProtoAction(rule.Action),
|
||||
Protocol: GetProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if ShouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is
|
||||
// unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if ShouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// GetProtoDirection converts the direction to proto.RuleDirection.
|
||||
func GetProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
// GetProtoAction converts the action to proto.RuleAction.
|
||||
func GetProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// GetProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func GetProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
case types.PolicyRuleProtocolNetbirdSSH:
|
||||
return proto.RuleProtocol_NETBIRD_SSH
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo.
|
||||
func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// ShouldUsePortRange reports whether the firewall rule should use a port
|
||||
// range rather than a single port (TCP/UDP without a single port).
|
||||
func ShouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall
|
||||
// rules to proto.
|
||||
func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: GetProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: GetProtoProtocol(rule.Protocol),
|
||||
PortInfo: GetProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form.
|
||||
func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form.
|
||||
func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// DNSConfigCache is the cache contract for amortising NameServerGroup
|
||||
// proto-conversion across peers in the same account. Server uses a concrete
|
||||
// implementation; client passes nil (no cross-peer caching needed when
|
||||
// rebuilding a single NetworkMap from an envelope).
|
||||
type DNSConfigCache interface {
|
||||
GetNameServerGroup(key string) (*proto.NameServerGroup, bool)
|
||||
SetNameServerGroup(key string, value *proto.NameServerGroup)
|
||||
}
|
||||
|
||||
// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is
|
||||
// non-nil, NameServerGroup proto values are cached by NSG.ID across calls —
|
||||
// the server amortises this across peers, the client passes nil.
|
||||
func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone))
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
if cache != nil {
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
continue
|
||||
}
|
||||
}
|
||||
protoGroup := ConvertToProtoNameServerGroup(nsGroup)
|
||||
if cache != nil {
|
||||
cache.SetNameServerGroup(nsGroup.ID, protoGroup)
|
||||
}
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig
|
||||
// entries to dst and returns the result.
|
||||
func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and
|
||||
// builds per-machine-user index maps. Returns (hashedUsers, machineUsers).
|
||||
// Errors from individual hash failures are logged via the provided context;
|
||||
// they leave the offending user out of the result but don't abort the build.
|
||||
func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to hash user id")
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
210
shared/management/networkmap/envelope.go
Normal file
210
shared/management/networkmap/envelope.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// EnvelopeResult is what the client engine consumes after receiving a
|
||||
// component-format NetworkMap. Both fields are populated:
|
||||
//
|
||||
// - NetworkMap is the *proto.NetworkMap shape the engine reads today via
|
||||
// update.GetNetworkMap() — built from the envelope's components by
|
||||
// running Calculate() locally + converting back through the shared
|
||||
// proto helpers + merging the optional ProxyPatch.
|
||||
// - Components is the *types.NetworkMapComponents the engine retains so
|
||||
// future incremental delta updates (Step 3) have a base to apply
|
||||
// changes against. The client keeps it under its sync lock.
|
||||
type EnvelopeResult struct {
|
||||
NetworkMap *proto.NetworkMap
|
||||
Components *types.NetworkMapComponents
|
||||
}
|
||||
|
||||
// EnvelopeToNetworkMap is the full client-side pipeline: decode the
|
||||
// component envelope back to a typed NetworkMapComponents, run Calculate()
|
||||
// locally to produce the typed NetworkMap, convert it to the wire form the
|
||||
// engine consumes, and fold in any ProxyPatch the server attached.
|
||||
//
|
||||
// localPeerKey is the receiving peer's WG pub key (used to derive
|
||||
// includeIPv6 / useSourcePrefixes from the receiving peer's own record in
|
||||
// the components struct, mirroring legacy ToSyncResponse behaviour).
|
||||
//
|
||||
// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when
|
||||
// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries.
|
||||
func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) {
|
||||
components, err := DecodeEnvelope(env)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode envelope: %w", err)
|
||||
}
|
||||
|
||||
// Find the receiving peer in the decoded components by WG key so we can
|
||||
// derive its capabilities and set components.PeerID for Calculate(). The
|
||||
// envelope.peers list is index-addressed; we synthesized IDs as "p<idx>".
|
||||
localPeerID, localPeer := findPeerByWgKey(components, localPeerKey)
|
||||
if localPeer == nil {
|
||||
return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers))
|
||||
}
|
||||
components.PeerID = localPeerID
|
||||
|
||||
includeIPv6 := localPeer.SupportsIPv6() && localPeer.IPv6.IsValid()
|
||||
useSourcePrefixes := localPeer.SupportsSourcePrefixes()
|
||||
|
||||
typedNM := components.Calculate(ctx)
|
||||
|
||||
full := env.GetFull()
|
||||
dnsFwdPort := int64(0)
|
||||
if full != nil {
|
||||
dnsFwdPort = full.DnsForwarderPort
|
||||
}
|
||||
|
||||
protoNM := &proto.NetworkMap{
|
||||
Serial: typedNM.Network.CurrentSerial(),
|
||||
}
|
||||
if full != nil {
|
||||
protoNM.PeerConfig = full.PeerConfig
|
||||
}
|
||||
protoNM.Routes = ToProtocolRoutes(typedNM.Routes)
|
||||
protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort)
|
||||
|
||||
remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6)
|
||||
protoNM.RemotePeers = remotePeers
|
||||
protoNM.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
|
||||
protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
protoNM.FirewallRules = firewallRules
|
||||
protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules)
|
||||
protoNM.RoutesFirewallRules = routesFirewallRules
|
||||
protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if typedNM.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers)
|
||||
userIDClaim := ""
|
||||
if full != nil {
|
||||
userIDClaim = full.UserIdClaim
|
||||
}
|
||||
protoNM.SshAuth = &proto.SSHAuth{
|
||||
AuthorizedUsers: hashedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
UserIDClaim: userIDClaim,
|
||||
}
|
||||
}
|
||||
|
||||
if typedNM.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules))
|
||||
for _, rule := range typedNM.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
protoNM.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
// Merge the proxy patch the server attached. Mirrors the legacy
|
||||
// NetworkMap.Merge step that the server runs after Calculate().
|
||||
if full != nil && full.ProxyPatch != nil {
|
||||
mergeProxyPatch(protoNM, full.ProxyPatch)
|
||||
}
|
||||
|
||||
return &EnvelopeResult{
|
||||
NetworkMap: protoNM,
|
||||
Components: components,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the
|
||||
// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge
|
||||
// — same six collections, deduplicated where the legacy merge dedupes.
|
||||
func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) {
|
||||
nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers)
|
||||
nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers)
|
||||
nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...)
|
||||
nm.Routes = append(nm.Routes, patch.Routes...)
|
||||
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...)
|
||||
nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...)
|
||||
if len(nm.RemotePeers) > 0 {
|
||||
nm.RemotePeersIsEmpty = false
|
||||
}
|
||||
if len(nm.FirewallRules) > 0 {
|
||||
nm.FirewallRulesIsEmpty = false
|
||||
}
|
||||
if len(nm.RoutesFirewallRules) > 0 {
|
||||
nm.RoutesFirewallRulesIsEmpty = false
|
||||
}
|
||||
}
|
||||
|
||||
// appendUniquePeers dedupes by WgPubKey — mirrors legacy
|
||||
// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the
|
||||
// closest stable identifier is WgPubKey).
|
||||
func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig {
|
||||
if len(extra) == 0 {
|
||||
return dst
|
||||
}
|
||||
seen := make(map[string]struct{}, len(dst))
|
||||
for _, p := range dst {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
seen[p.WgPubKey] = struct{}{}
|
||||
}
|
||||
for _, p := range extra {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p.WgPubKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p.WgPubKey] = struct{}{}
|
||||
dst = append(dst, p)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func trimKey(s string) string {
|
||||
if len(s) > 12 {
|
||||
return s[:12]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// findPeerByWgKey locates the receiving peer in the decoded components by
|
||||
// matching its WireGuard public key. Compares raw 32-byte decode output —
|
||||
// not the base64 string — because production data has occasional non-canonical
|
||||
// padding bits that round-trip through the envelope's `bytes wg_pub_key`
|
||||
// field, canonicalising the encoding (semantically equivalent key, different
|
||||
// string). Decodes `wgKey` once up front and reuses a stack buffer in the
|
||||
// loop so an N-peer search is ~zero-alloc.
|
||||
func findPeerByWgKey(c *types.NetworkMapComponents, wgKey string) (string, *nbpeer.Peer) {
|
||||
const wgKeyRawLen = 32
|
||||
var (
|
||||
targetRaw [wgKeyRawLen]byte
|
||||
haveRaw bool
|
||||
)
|
||||
if n, err := base64.StdEncoding.Decode(targetRaw[:], []byte(wgKey)); err == nil && n == wgKeyRawLen {
|
||||
haveRaw = true
|
||||
}
|
||||
var peerRaw [wgKeyRawLen]byte
|
||||
for id, p := range c.Peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.Key == wgKey {
|
||||
return id, p
|
||||
}
|
||||
if !haveRaw {
|
||||
continue
|
||||
}
|
||||
n, err := base64.StdEncoding.Decode(peerRaw[:], []byte(p.Key))
|
||||
if err == nil && n == wgKeyRawLen && bytes.Equal(peerRaw[:], targetRaw[:]) {
|
||||
return id, p
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
173
shared/management/networkmap/envelope_test.go
Normal file
173
shared/management/networkmap/envelope_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package networkmap_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline:
|
||||
// build a small components struct, encode an envelope, marshal/unmarshal the
|
||||
// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is
|
||||
// non-empty and consistent. Deeper per-field semantic equivalence with the
|
||||
// legacy server path is covered by the prod-DB equivalence test in
|
||||
// management/server/store/networkmap_envelope_equivalence_test.go.
|
||||
func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) {
|
||||
c, localPeerKey := buildSmokeComponents(t)
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
wire, err := goproto.Marshal(envelope)
|
||||
require.NoError(t, err, "marshal envelope")
|
||||
|
||||
var decoded proto.NetworkMapEnvelope
|
||||
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
|
||||
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
|
||||
require.NoError(t, err, "EnvelopeToNetworkMap")
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil")
|
||||
require.NotNil(t, result.Components, "Components must be retained for future delta updates")
|
||||
require.NotNil(t, result.Components.AccountSettings)
|
||||
require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer")
|
||||
require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules")
|
||||
}
|
||||
|
||||
// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the
|
||||
// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into
|
||||
// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP
|
||||
// before forming firewall rules (see networkmap_components.go:282 and
|
||||
// account.go:868). Without that rewrite, agents fall into UNKNOWN-protocol
|
||||
// handling, which on some platforms downgrades to allow-all — a real
|
||||
// security regression.
|
||||
func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) {
|
||||
c, localPeerKey := buildSmokeComponents(t)
|
||||
// Replace the smoke policy with a NetbirdSSH-protocol allow.
|
||||
c.Policies = []*types.Policy{{
|
||||
ID: "pol-ssh", AccountSeqID: 2, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-all"},
|
||||
Destinations: []string{"group-all"},
|
||||
}},
|
||||
}}
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
wire, err := goproto.Marshal(envelope)
|
||||
require.NoError(t, err)
|
||||
var decoded proto.NetworkMapEnvelope
|
||||
require.NoError(t, goproto.Unmarshal(wire, &decoded))
|
||||
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules")
|
||||
for i, fr := range result.NetworkMap.FirewallRules {
|
||||
require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol,
|
||||
"FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) {
|
||||
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud")
|
||||
require.Error(t, err, "nil envelope must produce an error rather than panic")
|
||||
}
|
||||
|
||||
func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) {
|
||||
env := &proto.NetworkMapEnvelope{}
|
||||
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud")
|
||||
require.Error(t, err, "envelope with no Full payload must produce an error")
|
||||
}
|
||||
|
||||
// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1
|
||||
// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient
|
||||
// to validate the encode → marshal → decode → Calculate pipeline produces
|
||||
// non-empty output.
|
||||
func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) {
|
||||
t.Helper()
|
||||
|
||||
peerAKey := randomWgKey(t)
|
||||
peerBKey := randomWgKey(t)
|
||||
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-A",
|
||||
Key: peerAKey,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peerA",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-B",
|
||||
Key: peerBKey,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
DNSLabel: "peerB",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-all", AccountSeqID: 1, Name: "All",
|
||||
Peers: []string{"peer-A", "peer-B"},
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-allow", AccountSeqID: 1, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-allow",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-all"},
|
||||
Destinations: []string{"group-all"},
|
||||
}},
|
||||
}
|
||||
|
||||
c := &types.NetworkMapComponents{
|
||||
PeerID: "peer-A",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-smoke",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 1,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{},
|
||||
DNSSettings: &types.DNSSettings{},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-A": peerA,
|
||||
"peer-B": peerB,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-all": group,
|
||||
},
|
||||
Policies: []*types.Policy{policy},
|
||||
}
|
||||
return c, peerAKey
|
||||
}
|
||||
|
||||
func randomWgKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
var raw [32]byte
|
||||
_, err := rand.Read(raw[:])
|
||||
require.NoError(t, err)
|
||||
return base64.StdEncoding.EncodeToString(raw[:])
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -133,6 +133,12 @@ message SyncResponse {
|
||||
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 6;
|
||||
|
||||
// NetworkMapEnvelope carries the component-based wire format for peers that
|
||||
// advertise PeerCapabilityComponentNetworkMap. When set, NetworkMap (field 5)
|
||||
// is left empty: management ships components and the client runs Calculate()
|
||||
// locally instead of receiving an expanded NetworkMap.
|
||||
NetworkMapEnvelope NetworkMapEnvelope = 7;
|
||||
}
|
||||
|
||||
message SyncMetaRequest {
|
||||
@@ -212,6 +218,8 @@ enum PeerCapability {
|
||||
PeerCapabilitySourcePrefixes = 1;
|
||||
// Client handles IPv6 overlay addresses and firewall rules.
|
||||
PeerCapabilityIPv6Overlay = 2;
|
||||
// Client receives NetworkMap as components and assembles it locally.
|
||||
PeerCapabilityComponentNetworkMap = 3;
|
||||
}
|
||||
|
||||
// PeerSystemMeta is machine meta data like OS and version.
|
||||
@@ -569,6 +577,13 @@ enum RuleProtocol {
|
||||
UDP = 3;
|
||||
ICMP = 4;
|
||||
CUSTOM = 5;
|
||||
// NETBIRD_SSH (types.PolicyRuleProtocolType "netbird-ssh") is the marker
|
||||
// policy rule that drives SSH-server activation in Calculate(). The legacy
|
||||
// proto.FirewallRule path doesn't ship this value (Calculate already
|
||||
// expands SSH rules into TCP/22 before encoding), but the components path
|
||||
// ships RAW policies — the client must see this protocol to derive
|
||||
// AuthorizedUsers locally.
|
||||
NETBIRD_SSH = 6;
|
||||
}
|
||||
|
||||
enum RuleDirection {
|
||||
@@ -709,3 +724,462 @@ message StopExposeRequest {
|
||||
}
|
||||
|
||||
message StopExposeResponse {}
|
||||
|
||||
// =====================================================================
|
||||
// Component-based NetworkMap wire format (PeerCapabilityComponentNetworkMap).
|
||||
//
|
||||
// Peers that advertise this capability receive NetworkMap building blocks
|
||||
// (peers + groups + policies + routes + dns + ssh + forwarding) and run the
|
||||
// expansion (Calculate) locally instead of receiving a fully-expanded
|
||||
// NetworkMap from the server.
|
||||
// =====================================================================
|
||||
|
||||
// NetworkMapEnvelope wraps either a full snapshot or a delta. Step 2 ships
|
||||
// only Full; Delta is reserved for the incremental-update work.
|
||||
message NetworkMapEnvelope {
|
||||
oneof payload {
|
||||
NetworkMapComponentsFull full = 1;
|
||||
NetworkMapComponentsDelta delta = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// NetworkMapComponentsFull is the full per-peer component snapshot. The
|
||||
// client decodes it into a types.NetworkMapComponents and runs Calculate()
|
||||
// locally to produce the same NetworkMap the legacy server path would have
|
||||
// produced. Every field carries RAW component data — no server-side
|
||||
// expansion (firewall rules, DNS config, SSH auth, route firewall rules,
|
||||
// forwarding rules) is shipped; the client computes those itself.
|
||||
message NetworkMapComponentsFull {
|
||||
uint64 serial = 1;
|
||||
|
||||
// Peer config for the receiving peer (legacy proto.PeerConfig kept as-is —
|
||||
// it carries the receiving peer's own overlay address, FQDN, SSH config).
|
||||
PeerConfig peer_config = 2;
|
||||
|
||||
// Account-level network metadata (id, IPv4/IPv6 overlay subnets, DNS,
|
||||
// serial). Mirrors types.Network.
|
||||
AccountNetwork network = 3;
|
||||
|
||||
// Account-level settings the client needs for its local Calculate().
|
||||
AccountSettingsCompact account_settings = 4;
|
||||
|
||||
// Account DNS settings (mirrors types.DNSSettings).
|
||||
DNSSettingsCompact dns_settings = 5;
|
||||
|
||||
// Domain shared across all peers in this account, e.g. "netbird.cloud".
|
||||
// Each peer's FQDN is dns_label + "." + dns_domain.
|
||||
string dns_domain = 6;
|
||||
|
||||
// Custom-zone domain for this peer's view (c.CustomZoneDomain). Empty when
|
||||
// the peer has no custom zone records.
|
||||
string custom_zone_domain = 7;
|
||||
|
||||
// Deduplicated agent versions; PeerCompact.agent_version_idx indexes here.
|
||||
// Empty string at index 0 if any peer has no version.
|
||||
repeated string agent_versions = 8;
|
||||
|
||||
// All peers (deduplicated). The client splits peers into online / offline
|
||||
// locally using account_settings.peer_login_expiration on receive.
|
||||
repeated PeerCompact peers = 9;
|
||||
|
||||
// Indexes into peers for the subset that may act as routers.
|
||||
repeated uint32 router_peer_indexes = 10;
|
||||
|
||||
// Policies that affect the receiving peer.
|
||||
repeated PolicyCompact policies = 11;
|
||||
|
||||
// Groups in unspecified order — clients key off id (account_seq_id).
|
||||
repeated GroupCompact groups = 12;
|
||||
|
||||
// Routes relevant to this peer, raw shape (mirrors []*route.Route).
|
||||
repeated RouteRaw routes = 13;
|
||||
|
||||
// Nameserver groups (mirrors []*nbdns.NameServerGroup).
|
||||
repeated NameServerGroupRaw nameserver_groups = 14;
|
||||
|
||||
// All DNS records the client needs to assemble its custom zone. Reuses
|
||||
// the existing SimpleRecord wire shape.
|
||||
repeated SimpleRecord all_dns_records = 15;
|
||||
|
||||
// Custom zones (typically the peer's own zone). Reuses the existing
|
||||
// CustomZone wire shape.
|
||||
repeated CustomZone account_zones = 16;
|
||||
|
||||
// Network resources (mirrors []*resourceTypes.NetworkResource).
|
||||
repeated NetworkResourceRaw network_resources = 17;
|
||||
|
||||
// Routers per network. Outer key: network account_seq_id. Each entry is
|
||||
// the set of routers backing that network for this peer's view.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: the map key changed from string (network xid)
|
||||
// to uint32 (account_seq_id). Field 18 was reused without a `reserved`
|
||||
// entry because capability=3 has never been released — every cap=3
|
||||
// producer and consumer carries the same regenerated descriptor. Do NOT
|
||||
// reuse this pattern for any further wire change once cap=3 ships.
|
||||
map<uint32, NetworkRouterList> routers_map = 18;
|
||||
|
||||
// For each NetworkResource account_seq_id, the indexes into policies[]
|
||||
// that apply to it.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
|
||||
map<uint32, PolicyIndexes> resource_policies_map = 19;
|
||||
|
||||
// Group-id (account_seq_id) → user ids authorized for SSH on members.
|
||||
map<uint32, UserIDList> group_id_to_user_ids = 20;
|
||||
|
||||
// Account-level allowed user ids (used by Calculate() when assembling SSH
|
||||
// authorized users for the receiving peer).
|
||||
repeated string allowed_user_ids = 21;
|
||||
|
||||
// Per posture-check account_seq_id, the set of peer indexes that failed
|
||||
// the check. Server-side evaluation result; clients do not re-evaluate.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
|
||||
map<uint32, PeerIndexSet> posture_failed_peers = 22;
|
||||
|
||||
// Account-level DNS forwarder port (mirrors the legacy
|
||||
// proto.DNSConfig.ForwarderPort). Computed by the controller from peer
|
||||
// versions; clients fold it into their Calculate() DNS output.
|
||||
int64 dns_forwarder_port = 23;
|
||||
|
||||
// Pre-expanded NetworkMap fragments injected post-Calculate by external
|
||||
// controllers (BYOP / port-forwarding proxies). The receiving client
|
||||
// merges these into its locally-computed NetworkMap the same way the
|
||||
// legacy server does via NetworkMap.Merge — so downstream consumers see
|
||||
// a unified merged result regardless of source.
|
||||
ProxyPatch proxy_patch = 24;
|
||||
|
||||
// SSH UserIDClaim — server-side HttpServerConfig.AuthUserIDClaim, or
|
||||
// "sub" by default. Populated in proto.SSHAuth.UserIDClaim when the
|
||||
// client rebuilds the NetworkMap from this envelope. Empty when the
|
||||
// account has no AuthorizedUsers (and thus no SshAuth to populate).
|
||||
string user_id_claim = 25;
|
||||
|
||||
// Reserved for future component additions (incremental_serial, parent_seq,
|
||||
// etc.) without forcing a renumber.
|
||||
reserved 26 to 50;
|
||||
}
|
||||
|
||||
// ProxyPatch carries NetworkMap fragments that don't fit the component-graph
|
||||
// model — they're pre-expanded by external controllers (BYOP /
|
||||
// port-forwarding proxies) and injected post-Calculate. Fields use the
|
||||
// legacy wire types because the proxy delivers them pre-formed; there is
|
||||
// no raw component shape to convert from. Empty when no proxy is active.
|
||||
message ProxyPatch {
|
||||
repeated RemotePeerConfig peers = 1;
|
||||
repeated RemotePeerConfig offline_peers = 2;
|
||||
repeated FirewallRule firewall_rules = 3;
|
||||
repeated Route routes = 4;
|
||||
repeated RouteFirewallRule route_firewall_rules = 5;
|
||||
repeated ForwardingRule forwarding_rules = 6;
|
||||
}
|
||||
|
||||
// AccountSettingsCompact carries the account-level settings the client needs
|
||||
// to evaluate locally. Mirrors the subset of types.AccountSettingsInfo that
|
||||
// Calculate() actually reads — login-expiration (used to filter expired
|
||||
// peers). Inactivity expiration is purely server-side bookkeeping and is not
|
||||
// shipped.
|
||||
message AccountSettingsCompact {
|
||||
bool peer_login_expiration_enabled = 1;
|
||||
// Login expiration window. Unit is nanoseconds (matches time.Duration).
|
||||
int64 peer_login_expiration_ns = 2;
|
||||
}
|
||||
|
||||
// AccountNetwork is the account-level overlay metadata. Mirrors types.Network
|
||||
// so the client can populate NetworkMap.Network without a server round-trip.
|
||||
message AccountNetwork {
|
||||
string identifier = 1;
|
||||
// IPv4 overlay subnet in CIDR form (e.g. "100.64.0.0/16").
|
||||
string net_cidr = 2;
|
||||
// IPv6 ULA overlay subnet in CIDR form (e.g. "fd00:4e42::/64"). Empty when
|
||||
// the account has no IPv6 overlay yet.
|
||||
string net_v6_cidr = 3;
|
||||
string dns = 4;
|
||||
uint64 serial = 5;
|
||||
}
|
||||
|
||||
// NetworkMapComponentsDelta is reserved for the incremental update protocol
|
||||
// (Step 3 of the migration plan). Field numbers 1–100 are pre-allocated to
|
||||
// keep room for the planned event types without needing a renumber.
|
||||
message NetworkMapComponentsDelta {
|
||||
reserved 1 to 100;
|
||||
}
|
||||
|
||||
// PeerCompact is the wire-shape of a remote peer used by the component
|
||||
// format. It carries every field of types.Peer that the client's local
|
||||
// Calculate() reads — including the trio needed to evaluate
|
||||
// LoginExpired() (added_with_sso_login + login_expiration_enabled +
|
||||
// last_login_unix_nano). Fields the client does not consume (Status,
|
||||
// CreatedAt, etc.) are not shipped.
|
||||
message PeerCompact {
|
||||
// Raw 32-byte WireGuard public key (no base64 wrapping).
|
||||
bytes wg_pub_key = 1;
|
||||
|
||||
// Raw 4-byte IPv4 overlay address. Always a /32 host route, so no prefix
|
||||
// byte is needed.
|
||||
bytes ip = 2;
|
||||
|
||||
// Raw 16-byte IPv6 overlay address; always a /128 host route. Empty when
|
||||
// the peer has no IPv6 overlay address.
|
||||
bytes ipv6 = 3;
|
||||
|
||||
// Raw SSH public key bytes (or empty).
|
||||
bytes ssh_pub_key = 4;
|
||||
|
||||
// DNS label without the account's domain suffix. Full FQDN is
|
||||
// dns_label + "." + NetworkMapComponentsFull.dns_domain.
|
||||
string dns_label = 5;
|
||||
|
||||
// Index into NetworkMapComponentsFull.agent_versions.
|
||||
uint32 agent_version_idx = 6;
|
||||
|
||||
// True iff the peer was added via SSO login (i.e., types.Peer.UserID is
|
||||
// non-empty). Combined with login_expiration_enabled and
|
||||
// last_login_unix_nano this lets the client reproduce
|
||||
// (*Peer).LoginExpired() locally.
|
||||
bool added_with_sso_login = 7;
|
||||
|
||||
// True when the peer's login can expire — mirrors
|
||||
// types.Peer.LoginExpirationEnabled.
|
||||
bool login_expiration_enabled = 8;
|
||||
|
||||
// Unix-nanosecond timestamp of the peer's last login. 0 when the peer has
|
||||
// never logged in (server stores nil; client treats 0 as "epoch", which
|
||||
// makes a fresh peer immediately expired iff login_expiration_enabled is
|
||||
// true — the same semantics as types.Peer.GetLastLogin).
|
||||
int64 last_login_unix_nano = 9;
|
||||
|
||||
// True when the peer has an SSH server enabled locally. Used by the
|
||||
// legacy SSH path in Calculate() (`policyRuleImpliesLegacySSH`): a rule
|
||||
// with protocol ALL/TCP-with-SSH-ports activates SSH for the receiving
|
||||
// peer when this bit is set, even without an explicit NetbirdSSH rule.
|
||||
bool ssh_enabled = 10;
|
||||
|
||||
reserved 11; // was: id (string xid)
|
||||
|
||||
// Mirror of types.Peer.SupportsIPv6() — !Meta.Flags.DisableIPv6 &&
|
||||
// HasCapability(PeerCapabilityIPv6Overlay). Used by the local peer's
|
||||
// Calculate() when deciding whether to emit IPv6 firewall rules
|
||||
// (appendIPv6FirewallRule) against this peer's IPv6 address.
|
||||
bool supports_ipv6 = 12;
|
||||
|
||||
// Mirror of types.Peer.SupportsSourcePrefixes() —
|
||||
// HasCapability(PeerCapabilitySourcePrefixes). Determines whether the
|
||||
// local peer's Calculate() emits SourcePrefixes alongside legacy PeerIP
|
||||
// fields in proto.FirewallRule.
|
||||
bool supports_source_prefixes = 13;
|
||||
|
||||
// Mirror of types.Peer.Meta.Flags.ServerSSHAllowed. Read by Calculate()
|
||||
// when expanding TCP port-22 firewall rules — the native SSH companion
|
||||
// (port 22022) is only added when this flag is set and the peer agent
|
||||
// version supports it.
|
||||
bool server_ssh_allowed = 14;
|
||||
}
|
||||
|
||||
// PolicyCompact is the compact form of a policy rule. Group references use
|
||||
// the per-account integer ids from account_seq_counters; the client resolves
|
||||
// them against NetworkMapComponentsFull.groups. Direction is derived per-peer
|
||||
// on the client (ingress when the peer is in destination_group_ids, egress
|
||||
// when in source_group_ids; both when bidirectional).
|
||||
message PolicyCompact {
|
||||
// Per-account integer id (matches policies.account_seq_id). Used as a
|
||||
// stable reference for ResourcePoliciesMap.indexes and future delta
|
||||
// updates (Step 3).
|
||||
uint32 id = 1;
|
||||
|
||||
RuleAction action = 2;
|
||||
RuleProtocol protocol = 3;
|
||||
bool bidirectional = 4;
|
||||
|
||||
// Single ports referenced by the rule.
|
||||
repeated uint32 ports = 5;
|
||||
|
||||
// Port ranges (start..end) referenced by the rule.
|
||||
repeated PortInfo.Range port_ranges = 6;
|
||||
|
||||
// Group ids (account_seq_id) of source / destination groups.
|
||||
repeated uint32 source_group_ids = 7;
|
||||
repeated uint32 destination_group_ids = 8;
|
||||
|
||||
reserved 9; // was: xid (string)
|
||||
|
||||
// SSH authorization fields. PolicyRule.AuthorizedGroups maps the rule's
|
||||
// applicable group ids (account_seq_id) to a list of local-user names —
|
||||
// when a peer in one of those groups is the SSH destination, the named
|
||||
// local users gain access. AuthorizedUser is the single-user form
|
||||
// (legacy: rule scopes SSH to one specific user id).
|
||||
//
|
||||
// Both fields are only consumed by Calculate() when the rule's protocol
|
||||
// is NetbirdSSH (or the legacy implicit-SSH heuristic).
|
||||
map<uint32, UserNameList> authorized_groups = 10;
|
||||
string authorized_user = 11;
|
||||
|
||||
// Resource-typed rule sources/destinations. When a rule targets a specific
|
||||
// peer (rather than groups), Calculate() reads SourceResource /
|
||||
// DestinationResource — without these the rule's connection resources
|
||||
// can't be produced on the client. ResourceCompact's peer_index refers to
|
||||
// NetworkMapComponentsFull.peers; type is the raw ResourceType string
|
||||
// ("peer", "host", "subnet", "domain"). Only "peer" is meaningful for
|
||||
// Calculate's resource-typed rule path today.
|
||||
ResourceCompact source_resource = 12;
|
||||
ResourceCompact destination_resource = 13;
|
||||
|
||||
// Posture-check seq ids gating this policy's source peers. Calculate()
|
||||
// reads them when filtering rule peers (peers that fail any listed check
|
||||
// are dropped from sourcePeers). Match keys in
|
||||
// NetworkMapComponentsFull.posture_failed_peers.
|
||||
repeated uint32 source_posture_check_seq_ids = 15;
|
||||
|
||||
reserved 14; // was: source_posture_check_ids (repeated string xid)
|
||||
}
|
||||
|
||||
// ResourceCompact mirrors types.Resource. Used by PolicyCompact to carry
|
||||
// rule.SourceResource / rule.DestinationResource when the rule targets a
|
||||
// specific resource (typically a peer) rather than groups.
|
||||
// peer_index_set tells whether peer_index is valid (proto3 uint32 cannot
|
||||
// disambiguate "0" from "unset"); set only when type == "peer".
|
||||
message ResourceCompact {
|
||||
string type = 1;
|
||||
bool peer_index_set = 2;
|
||||
uint32 peer_index = 3;
|
||||
reserved 4; // future: host/subnet/domain references when needed
|
||||
}
|
||||
|
||||
// UserNameList is a list of local-user names — used as the value type in
|
||||
// PolicyCompact.authorized_groups.
|
||||
message UserNameList {
|
||||
repeated string names = 1;
|
||||
}
|
||||
|
||||
// GroupCompact is the wire-shape of a group: per-account integer id, optional
|
||||
// name, and indexes into NetworkMapComponentsFull.peers identifying members.
|
||||
message GroupCompact {
|
||||
// Per-account integer id (matches groups.account_seq_id). Used by
|
||||
// PolicyCompact.source_group_ids / destination_group_ids.
|
||||
uint32 id = 1;
|
||||
|
||||
// Group name; only sent when non-empty (clients use it for diagnostics).
|
||||
string name = 2;
|
||||
|
||||
// Indexes into NetworkMapComponentsFull.peers.
|
||||
repeated uint32 peer_indexes = 3;
|
||||
}
|
||||
|
||||
// DNSSettingsCompact mirrors types.DNSSettings.
|
||||
message DNSSettingsCompact {
|
||||
// Group ids (account_seq_id) whose DNS management is disabled.
|
||||
repeated uint32 disabled_management_group_ids = 1;
|
||||
}
|
||||
|
||||
// RouteRaw mirrors *route.Route (the domain type), trimmed to fields that
|
||||
// types.NetworkMapComponents.Calculate() reads. Group references are
|
||||
// account_seq_ids; the routing peer (when set) is referenced by index into
|
||||
// NetworkMapComponentsFull.peers.
|
||||
message RouteRaw {
|
||||
// Per-account integer id (matches routes.account_seq_id).
|
||||
uint32 id = 1;
|
||||
string net_id = 2;
|
||||
string description = 3;
|
||||
|
||||
// Either network_cidr (e.g. "10.0.0.0/16") or domains is set, not both.
|
||||
string network_cidr = 4;
|
||||
repeated string domains = 5;
|
||||
bool keep_route = 6;
|
||||
|
||||
// Routing peer reference: peer_index_set tells whether peer_index is valid
|
||||
// (proto3 uint32 cannot disambiguate "0" from "unset"). Mutually exclusive
|
||||
// with peer_group_ids.
|
||||
//
|
||||
// peer_index decodes back to types.Peer.ID (the peer's xid string), NOT
|
||||
// to its WireGuard public key. This matches the server-side data flow:
|
||||
// c.Routes carry route.Peer = peer.ID, and getRoutingPeerRoutes mutates
|
||||
// it to peer.Key only after the route has been admitted to the network
|
||||
// map. Decoders MUST set Route.Peer = peer.ID; the legacy Calculate()
|
||||
// path will substitute the WG key downstream.
|
||||
bool peer_index_set = 7;
|
||||
uint32 peer_index = 8;
|
||||
repeated uint32 peer_group_ids = 9;
|
||||
|
||||
int32 network_type = 10;
|
||||
bool masquerade = 11;
|
||||
int32 metric = 12;
|
||||
bool enabled = 13;
|
||||
repeated uint32 group_ids = 14;
|
||||
repeated uint32 access_control_group_ids = 15;
|
||||
bool skip_auto_apply = 16;
|
||||
|
||||
reserved 17; // was: xid (string)
|
||||
}
|
||||
|
||||
// NameServerGroupRaw mirrors *nbdns.NameServerGroup. Distinct from the
|
||||
// legacy NameServerGroup (which is the wire-trimmed shape consumed by
|
||||
// proto.DNSConfig and lacks the Name/Description/Groups/Enabled fields).
|
||||
message NameServerGroupRaw {
|
||||
uint32 id = 1; // nameserver_groups.account_seq_id
|
||||
string name = 2;
|
||||
string description = 3;
|
||||
// Reuses the legacy NameServer wire shape (IP as string).
|
||||
repeated NameServer nameservers = 4;
|
||||
// Group ids (account_seq_id) the NSG distributes nameservers to.
|
||||
repeated uint32 group_ids = 5;
|
||||
bool primary = 6;
|
||||
repeated string domains = 7;
|
||||
bool enabled = 8;
|
||||
bool search_domains_enabled = 9;
|
||||
}
|
||||
|
||||
// NetworkResourceRaw mirrors *resourceTypes.NetworkResource.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: field 2 changed from `string network_id` (xid)
|
||||
// to `uint32 network_seq` without a `reserved` entry. Safe only because
|
||||
// capability=3 has never been released — every cap=3 producer and consumer
|
||||
// carries the same regenerated descriptor. Do NOT reuse this pattern once
|
||||
// cap=3 ships.
|
||||
message NetworkResourceRaw {
|
||||
uint32 id = 1; // network_resources.account_seq_id
|
||||
uint32 network_seq = 2; // networks.account_seq_id (replaces xid)
|
||||
string name = 3;
|
||||
string description = 4;
|
||||
// Resource type: "host" / "subnet" / "domain".
|
||||
string type = 5;
|
||||
string address = 6;
|
||||
string domain_value = 7; // resource.Domain
|
||||
string prefix_cidr = 8;
|
||||
bool enabled = 9;
|
||||
reserved 10; // was: xid (string)
|
||||
}
|
||||
|
||||
// NetworkRouterList carries the routers backing one network.
|
||||
message NetworkRouterList {
|
||||
// Routers in this network, keyed by peer_index (the routing peer).
|
||||
repeated NetworkRouterEntry entries = 1;
|
||||
}
|
||||
|
||||
// NetworkRouterEntry mirrors a single *routerTypes.NetworkRouter; the routing
|
||||
// peer is referenced by index into NetworkMapComponentsFull.peers.
|
||||
message NetworkRouterEntry {
|
||||
uint32 id = 1; // network_routers.account_seq_id
|
||||
uint32 peer_index = 2;
|
||||
bool peer_index_set = 3;
|
||||
repeated uint32 peer_group_ids = 4;
|
||||
bool masquerade = 5;
|
||||
int32 metric = 6;
|
||||
bool enabled = 7;
|
||||
}
|
||||
|
||||
// PolicyIndexes is a list of indexes into NetworkMapComponentsFull.policies.
|
||||
message PolicyIndexes {
|
||||
repeated uint32 indexes = 1;
|
||||
}
|
||||
|
||||
// UserIDList is a list of user ids — used as the value type in
|
||||
// NetworkMapComponentsFull.group_id_to_user_ids.
|
||||
message UserIDList {
|
||||
repeated string user_ids = 1;
|
||||
}
|
||||
|
||||
// PeerIndexSet is a set of peer indexes — used as the value type in
|
||||
// NetworkMapComponentsFull.posture_failed_peers.
|
||||
message PeerIndexSet {
|
||||
repeated uint32 peer_indexes = 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user