mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-27 02:10:00 +00:00
Compare commits
24 Commits
add-enterp
...
nmap/compo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1347ad1ff | ||
|
|
ff2787e184 | ||
|
|
e20b62ad65 | ||
|
|
18b38943aa | ||
|
|
db76f33b71 | ||
|
|
eab0826b4e | ||
|
|
7048b87931 | ||
|
|
596952265d | ||
|
|
21cfec93d4 | ||
|
|
98818e3095 | ||
|
|
5d5c2d9f95 | ||
|
|
13e41e432c | ||
|
|
efa6a3f502 | ||
|
|
5fbcdeceac | ||
|
|
3a1bbeba90 | ||
|
|
728057ef15 | ||
|
|
582cd70086 | ||
|
|
9bbbafaf69 | ||
|
|
672b057aa0 | ||
|
|
b9a0186200 | ||
|
|
9083bdb977 | ||
|
|
b194af48b8 | ||
|
|
4543780ef0 | ||
|
|
2de0283971 |
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
display_name: Linux
|
||||
name: ${{ matrix.display_name }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
@@ -62,4 +62,4 @@ jobs:
|
||||
skip-cache: true
|
||||
skip-save-cache: true
|
||||
cache-invalidation-interval: 0
|
||||
args: --timeout=12m
|
||||
args: --timeout=20m
|
||||
|
||||
@@ -63,7 +63,9 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
types "github.com/netbirdio/netbird/shared/management/types"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -210,6 +212,13 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// latestComponents is the most-recent NetworkMapComponents decoded from
|
||||
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
|
||||
// NetworkMap that Calculate() produced from it so future incremental
|
||||
// updates have a base to apply changes against. nil for legacy-format
|
||||
// peers. Guarded by syncMsgMux.
|
||||
latestComponents *types.NetworkMapComponents
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -910,20 +919,54 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
// Envelope sync responses carry PeerConfig at the top level; legacy
|
||||
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
|
||||
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
|
||||
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode the network map from either the components envelope or the
|
||||
// legacy proto.NetworkMap before the posture-check gating below, so the
|
||||
// "is there a network map" decision covers both wire shapes.
|
||||
var (
|
||||
nm *mgmProto.NetworkMap
|
||||
components *types.NetworkMapComponents
|
||||
)
|
||||
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
|
||||
// Components-format peer: decode the envelope back to typed
|
||||
// components, run Calculate() locally, and convert to the wire
|
||||
// NetworkMap shape the rest of the engine consumes. Components are
|
||||
// retained so future incremental updates can apply deltas instead
|
||||
// of doing a full reconstruction.
|
||||
localKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
dnsName := ""
|
||||
if pc := update.GetPeerConfig(); pc != nil {
|
||||
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
|
||||
// shared domain by stripping the peer's own label prefix. Falls
|
||||
// back to empty if the FQDN doesn't have the expected shape.
|
||||
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
|
||||
}
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode network map envelope: %w", err)
|
||||
}
|
||||
nm = result.NetworkMap
|
||||
components = result.Components
|
||||
} else {
|
||||
nm = update.GetNetworkMap()
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -932,6 +975,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only retain the components view when the server sent the envelope
|
||||
// path. A legacy proto.NetworkMap means components == nil; writing it
|
||||
// here would clobber a previously-cached snapshot, breaking the
|
||||
// incremental-delta base on a future envelope sync.
|
||||
if components != nil {
|
||||
e.latestComponents = components
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
@@ -944,6 +995,19 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
|
||||
// receiving peer's FQDN — the same value the management server fills as
|
||||
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
|
||||
// "netbird.cloud". An empty string is returned for unrecognized formats.
|
||||
func extractDNSDomainFromFQDN(fqdn string) string {
|
||||
for i := 0; i < len(fqdn); i++ {
|
||||
if fqdn[i] == '.' && i+1 < len(fqdn) {
|
||||
return fqdn[i+1:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
|
||||
@@ -67,7 +67,6 @@ var boolStringLiterals = map[string]bool{
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
|
||||
@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
|
||||
@@ -336,11 +336,11 @@ type serviceClient struct {
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showQuickActions:
|
||||
s.showQuickActionsUI()
|
||||
// Suppress the on-boot Quick Actions popup when the daemon
|
||||
// reports DisableAutoConnect=true — that flag carries both the
|
||||
// user's "Connect on Startup = off" preference AND any MDM-
|
||||
// enforced override (applyMDMPolicy writes the policy value
|
||||
// into the same Config field). See netbirdio/netbird#5744.
|
||||
if !s.disableAutoConnectFromDaemon() {
|
||||
s.showQuickActionsUI()
|
||||
}
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
}
|
||||
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// disableAutoConnectFromDaemon returns true when the daemon reports
|
||||
// the active profile has DisableAutoConnect=true. Used by the
|
||||
// --quick-actions startup path to suppress the on-boot popup when the
|
||||
// user (or an MDM admin) opted out of auto-connecting; both cases
|
||||
// converge on the same Config field because applyMDMPolicy writes the
|
||||
// policy value into it. Returns false on any RPC / lookup failure so a
|
||||
// daemon hiccup does not silently swallow the popup.
|
||||
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
|
||||
return false
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
|
||||
return false
|
||||
}
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
|
||||
return false
|
||||
}
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
|
||||
return false
|
||||
}
|
||||
return srvCfg.GetDisableAutoConnect()
|
||||
}
|
||||
|
||||
// getSrvConfig from the service to show it in the settings window.
|
||||
func (s *serviceClient) getSrvConfig() {
|
||||
s.managementURL = profilemanager.DefaultManagementURL
|
||||
|
||||
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||
// Name group name
|
||||
Name string
|
||||
// Description group description
|
||||
|
||||
@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
|
||||
if file == "" {
|
||||
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
|
||||
}
|
||||
return (&sql.SQLite3{File: file}).Open(logger)
|
||||
return newSQLite3(file).Open(logger)
|
||||
case "postgres":
|
||||
dsn, _ := s.Config["dsn"].(string)
|
||||
if dsn == "" {
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/server/signer"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/sql"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -79,7 +78,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||
|
||||
// Initialize SQLite storage
|
||||
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
||||
sqliteConfig := &sql.SQLite3{File: dbPath}
|
||||
sqliteConfig := newSQLite3(dbPath)
|
||||
stor, err := sqliteConfig.Open(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open storage: %w", err)
|
||||
|
||||
15
idp/dex/sqlite_cgo.go
Normal file
15
idp/dex/sqlite_cgo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
sql "github.com/dexidp/dex/storage/sql"
|
||||
)
|
||||
|
||||
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
|
||||
// struct that takes a File path. Non-CGO builds get an empty stub whose
|
||||
// Open() returns the dex "SQLite not available" error — correct behaviour
|
||||
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
|
||||
func newSQLite3(file string) *sql.SQLite3 {
|
||||
return &sql.SQLite3{File: file}
|
||||
}
|
||||
15
idp/dex/sqlite_nocgo.go
Normal file
15
idp/dex/sqlite_nocgo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !cgo
|
||||
|
||||
package dex
|
||||
|
||||
import (
|
||||
sql "github.com/dexidp/dex/storage/sql"
|
||||
)
|
||||
|
||||
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
|
||||
// Open() returns an error documenting the missing CGO support — correct
|
||||
// behaviour for cross-compiled artefacts that never actually run the
|
||||
// embedded IdP. The `file` argument is ignored.
|
||||
func newSQLite3(_ string) *sql.SQLite3 {
|
||||
return &sql.SQLite3{}
|
||||
}
|
||||
@@ -9,8 +9,6 @@ set -o pipefail
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
NETBIRD_EULA_URL="https://trust.netbird.io/?tab=reports-and-documents"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -141,43 +139,6 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the install on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read it in
|
||||
full before continuing — open the "On-Premise EULA" document here:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting installation." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
@@ -213,9 +174,6 @@ init_environment() {
|
||||
exit 1
|
||||
fi
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
@@ -302,11 +260,6 @@ render_env() {
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# NetBird On-Premise EULA acceptance
|
||||
NETBIRD_EULA_ACCEPTED=yes
|
||||
NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}
|
||||
NETBIRD_EULA_URL=${NETBIRD_EULA_URL}
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ set -o pipefail
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
NETBIRD_EULA_URL="https://trust.netbird.io/?tab=reports-and-documents"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -117,43 +115,6 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the migration on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read it in
|
||||
full before continuing — open the "On-Premise EULA" document here:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting migration." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
@@ -423,9 +384,6 @@ init_migration() {
|
||||
check_yq
|
||||
check_openssl
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
|
||||
|
||||
if [[ ! -f "$COMPOSE_FILE" ]]; then
|
||||
@@ -571,10 +529,6 @@ apply_changes() {
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "# NetBird On-Premise EULA accepted at install time"
|
||||
echo "NETBIRD_EULA_ACCEPTED=yes"
|
||||
echo "NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}"
|
||||
echo "NETBIRD_EULA_URL=${NETBIRD_EULA_URL}"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
|
||||
@@ -56,6 +56,12 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
// componentsDisabled, when true, forces the controller to emit legacy
|
||||
// proto.NetworkMap to every peer regardless of capability. Set once at
|
||||
// construction and never written after — readers race-free without a
|
||||
// mutex.
|
||||
componentsDisabled bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -89,12 +95,27 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
settingsManager: settingsManager,
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
}
|
||||
}
|
||||
|
||||
// PeerNeedsComponents reports whether the gRPC layer should emit the
|
||||
// component-based wire format for this peer.
|
||||
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
|
||||
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
|
||||
}
|
||||
|
||||
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
|
||||
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
|
||||
// literal.
|
||||
func parseBoolEnv(key string) bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(key))
|
||||
return v
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
@@ -204,18 +225,26 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[p.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
|
||||
// the client merges it into Calculate()'s output the same
|
||||
// way the legacy server did via NetworkMap.Merge.
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
@@ -425,11 +454,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
proxyNetworkMap := proxyNetworkMaps[peer.ID]
|
||||
if result.NetworkMap != nil && proxyNetworkMap != nil {
|
||||
result.NetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
|
||||
@@ -440,7 +469,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
var update *proto.SyncResponse
|
||||
if result.IsComponents() {
|
||||
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
} else {
|
||||
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
}
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
@@ -487,6 +521,66 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents is the components-format counterpart of
|
||||
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
|
||||
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
|
||||
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
|
||||
// encodes both into the wire envelope. Callers must gate on capability
|
||||
// themselves before dispatching here — this method does NOT branch on it.
|
||||
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
// Fetch the proxy network map fragment for this peer alongside the
|
||||
// components — same single-account-load path the streaming controller
|
||||
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
|
||||
// instead of waiting for the next streaming push.
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
@@ -497,7 +591,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
|
||||
@@ -24,6 +24,10 @@ type Controller interface {
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
// PeerNeedsComponents combines the peer's advertised capability with the
|
||||
// kill-switch flag — the only public predicate gRPC layers should ask.
|
||||
PeerNeedsComponents(p *nbpeer.Peer) bool
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
|
||||
@@ -143,6 +143,39 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents mocks base method.
|
||||
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(*types.NetworkMapComponents)
|
||||
ret2, _ := ret[2].(*types.NetworkMap)
|
||||
ret3, _ := ret[3].([]*posture.Checks)
|
||||
ret4, _ := ret[4].(int64)
|
||||
ret5, _ := ret[5].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||
}
|
||||
|
||||
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
|
||||
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// PeerNeedsComponents mocks base method.
|
||||
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
|
||||
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
|
||||
}
|
||||
|
||||
// OnPeerConnected mocks base method.
|
||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
813
management/internals/shared/grpc/components_encoder.go
Normal file
813
management/internals/shared/grpc/components_encoder.go
Normal file
@@ -0,0 +1,813 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// wgKeyRawLen is the raw byte length of a WireGuard public key.
|
||||
const wgKeyRawLen = 32
|
||||
|
||||
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
|
||||
// The envelope is fully self-contained — every field needed by the client's
|
||||
// local Calculate() comes from the components struct itself. The only
|
||||
// externally-supplied data is the receiving peer's PeerConfig (which is
|
||||
// computed alongside the components in the network_map controller and reused
|
||||
// from the legacy proto path) and the dns_domain string.
|
||||
type ComponentsEnvelopeInput struct {
|
||||
Components *types.NetworkMapComponents
|
||||
PeerConfig *proto.PeerConfig
|
||||
DNSDomain string
|
||||
DNSForwarderPort int64
|
||||
// UserIDClaim is the OIDC claim name the client should embed in
|
||||
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
|
||||
// is OK — client treats empty as "no SshAuth to build".
|
||||
UserIDClaim string
|
||||
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
|
||||
// external controllers (BYOP/port-forwarding). Nil when no proxy data
|
||||
// is present; encoder skips the field in that case.
|
||||
ProxyPatch *proto.ProxyPatch
|
||||
}
|
||||
|
||||
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
|
||||
// wire envelope. The encoder is intentionally non-deterministic: it iterates
|
||||
// Go maps in their native (random) order. Indexes inside the envelope
|
||||
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
|
||||
// are self-consistent within a single encode, so the decoder reconstructs
|
||||
// the same typed objects regardless of emit order. Tests that need to
|
||||
// compare envelopes do so semantically via proto round-trip + canonicalize,
|
||||
// not byte-equal.
|
||||
//
|
||||
// Callers must NOT concatenate or merge envelopes from different encodes —
|
||||
// index spaces are local to a single envelope.
|
||||
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
|
||||
c := in.Components
|
||||
|
||||
// Graceful degrade when components is nil — matches the legacy path's
|
||||
// behaviour for missing/unvalidated peers (return a NetworkMap with only
|
||||
// Network populated). The receiver gets an envelope it can decode
|
||||
// without crashing; AccountSettings stays non-nil so client-side
|
||||
// dereferences are safe.
|
||||
if c == nil {
|
||||
// Match legacy missing-peer minimum: a NetworkMap with only Network
|
||||
// populated. The receiver gets enough to bootstrap (Network
|
||||
// identifier, dns_domain, account_settings) and nothing else.
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{
|
||||
Full: &proto.NetworkMapComponentsFull{
|
||||
PeerConfig: in.PeerConfig,
|
||||
DnsDomain: in.DNSDomain,
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
AccountSettings: &proto.AccountSettingsCompact{},
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
|
||||
// every regular peer (in c.Peers) must be indexed before any encoder
|
||||
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
|
||||
// peers that exist only in c.RouterPeers would silently lose their
|
||||
// peer_index reference.
|
||||
enc := newComponentEncoder(c)
|
||||
enc.indexAllPeers()
|
||||
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
|
||||
|
||||
// Phase 2: gather every policy that any consumer references (peer-pair
|
||||
// policies + resource-only policies) so encodeResourcePoliciesMap can
|
||||
// translate every *Policy pointer to a wire index.
|
||||
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
|
||||
policies, policyToIdxs := enc.encodePolicies(allPolicies)
|
||||
|
||||
// Phase 3: emit. Order of struct field expressions no longer matters:
|
||||
// every encoder either reads from the dedup tables or works on
|
||||
// independent input.
|
||||
full := &proto.NetworkMapComponentsFull{
|
||||
Serial: networkSerial(c.Network),
|
||||
PeerConfig: in.PeerConfig,
|
||||
Network: toAccountNetwork(c.Network),
|
||||
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
|
||||
DnsForwarderPort: in.DNSForwarderPort,
|
||||
UserIdClaim: in.UserIDClaim,
|
||||
ProxyPatch: in.ProxyPatch,
|
||||
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
|
||||
DnsDomain: in.DNSDomain,
|
||||
CustomZoneDomain: c.CustomZoneDomain,
|
||||
AgentVersions: enc.agentVersions,
|
||||
Peers: enc.peers,
|
||||
RouterPeerIndexes: routerIdxs,
|
||||
Policies: policies,
|
||||
Groups: enc.encodeGroups(),
|
||||
Routes: enc.encodeRoutes(c.Routes),
|
||||
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
|
||||
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
|
||||
AccountZones: encodeCustomZones(c.AccountZones),
|
||||
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
|
||||
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
|
||||
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
|
||||
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
|
||||
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
|
||||
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
|
||||
}
|
||||
|
||||
return &proto.NetworkMapEnvelope{
|
||||
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
|
||||
}
|
||||
}
|
||||
|
||||
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
|
||||
// production path always populates c.Network, but the encoder is exported
|
||||
// and a hand-built components struct may omit it.
|
||||
func networkSerial(n *types.Network) uint64 {
|
||||
if n == nil {
|
||||
return 0
|
||||
}
|
||||
return n.CurrentSerial()
|
||||
}
|
||||
|
||||
type componentEncoder struct {
|
||||
components *types.NetworkMapComponents
|
||||
|
||||
peerOrder map[string]uint32
|
||||
peers []*proto.PeerCompact
|
||||
|
||||
agentVersionOrder map[string]uint32
|
||||
agentVersions []string
|
||||
}
|
||||
|
||||
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
|
||||
return &componentEncoder{
|
||||
components: c,
|
||||
peerOrder: make(map[string]uint32, len(c.Peers)),
|
||||
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
|
||||
agentVersionOrder: make(map[string]uint32),
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) indexAllPeers() {
|
||||
for _, p := range e.components.Peers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
e.appendPeer(p)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
|
||||
if idx, ok := e.peerOrder[p.ID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := uint32(len(e.peers))
|
||||
e.peerOrder[p.ID] = idx
|
||||
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
|
||||
return idx
|
||||
}
|
||||
|
||||
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
|
||||
if idx, ok := e.agentVersionOrder[v]; ok {
|
||||
return idx
|
||||
}
|
||||
// Lazy-initialise the table with "" at index 0 so the empty string
|
||||
// stays interchangeable with proto3's default uint32=0 — peers without
|
||||
// a WtVersion don't force the table to materialise.
|
||||
if v == "" {
|
||||
idx := uint32(len(e.agentVersions))
|
||||
if idx == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
}
|
||||
e.agentVersionOrder[""] = idx
|
||||
return idx
|
||||
}
|
||||
if len(e.agentVersions) == 0 {
|
||||
e.agentVersions = append(e.agentVersions, "")
|
||||
e.agentVersionOrder[""] = 0
|
||||
}
|
||||
idx := uint32(len(e.agentVersions))
|
||||
e.agentVersionOrder[v] = idx
|
||||
e.agentVersions = append(e.agentVersions, v)
|
||||
return idx
|
||||
}
|
||||
|
||||
// indexRouterPeers ensures every router peer is in the peer dedup table
|
||||
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
|
||||
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
|
||||
// run before any encoder that resolves peer ids via e.peerOrder.
|
||||
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
|
||||
if len(routers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(routers))
|
||||
for _, p := range routers {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, e.appendPeer(p))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
|
||||
if len(e.components.Groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
|
||||
for _, g := range e.components.Groups {
|
||||
if !g.HasSeqID() {
|
||||
continue
|
||||
}
|
||||
peerIdxs := make([]uint32, 0, len(g.Peers))
|
||||
for _, peerID := range g.Peers {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
peerIdxs = append(peerIdxs, idx)
|
||||
}
|
||||
}
|
||||
out = append(out, &proto.GroupCompact{
|
||||
Id: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
PeerIndexes: peerIdxs,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
|
||||
// list and a map from policy pointer to the indexes of its emitted rules in
|
||||
// that list — used by encodeResourcePoliciesMap to translate
|
||||
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
|
||||
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make([]*proto.PolicyCompact, 0, len(policies))
|
||||
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
|
||||
|
||||
for _, pol := range policies {
|
||||
if !pol.HasSeqID() || !pol.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, r := range pol.Rules {
|
||||
if r == nil || !r.Enabled {
|
||||
continue
|
||||
}
|
||||
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
|
||||
out = append(out, e.encodePolicyRule(pol, r))
|
||||
}
|
||||
}
|
||||
return out, idxByPolicy
|
||||
}
|
||||
|
||||
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
|
||||
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
|
||||
return &proto.PolicyCompact{
|
||||
Id: pol.AccountSeqID,
|
||||
Action: networkmap.GetProtoAction(string(r.Action)),
|
||||
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
|
||||
Bidirectional: r.Bidirectional,
|
||||
Ports: portsToUint32(r.Ports),
|
||||
PortRanges: portRangesToProto(r.PortRanges),
|
||||
SourceGroupIds: e.groupSeqIDs(r.Sources),
|
||||
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
|
||||
AuthorizedUser: r.AuthorizedUser,
|
||||
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
|
||||
SourceResource: e.resourceToProto(r.SourceResource),
|
||||
DestinationResource: e.resourceToProto(r.DestinationResource),
|
||||
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
|
||||
}
|
||||
}
|
||||
|
||||
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
|
||||
// dropping any group that has no seq id assigned.
|
||||
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(src))
|
||||
for _, gid := range src {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// unionPolicies merges c.Policies with every policy referenced by
|
||||
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
|
||||
// policies (relevant to a NetworkResource but not to peer-pair traffic)
|
||||
// only live in ResourcePoliciesMap; without this union step they'd be lost
|
||||
// from the wire and the client's resource-policy lookup would come back
|
||||
// empty.
|
||||
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
|
||||
// Fast path: non-router peers have no resource-only policies, so the
|
||||
// "union" is identical to `policies`. Skip the dedup map allocation.
|
||||
if len(resourcePolicies) == 0 {
|
||||
return policies
|
||||
}
|
||||
seen := make(map[*types.Policy]struct{}, len(policies))
|
||||
out := make([]*types.Policy, 0, len(policies))
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
for _, list := range resourcePolicies {
|
||||
for _, p := range list {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
|
||||
// group xid → local-user names) to the wire form (map keyed by group
|
||||
// account_seq_id → UserNameList). Groups without a seq id are dropped —
|
||||
// matches how source/destination group references handle the same case.
|
||||
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserNameList, len(m))
|
||||
for groupID, names := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
|
||||
g, ok := e.components.Groups[groupID]
|
||||
if !ok || !g.HasSeqID() {
|
||||
return 0, false
|
||||
}
|
||||
return g.AccountSeqID, true
|
||||
}
|
||||
|
||||
// resourceToProto translates types.Resource for the wire. For peer-typed
|
||||
// resources the peer id is converted to a peer index into the envelope's
|
||||
// peers array. For other resource types only the type string is shipped
|
||||
// today (Calculate's resource-typed rule path consults SourceResource only
|
||||
// for "peer" — other types fall through to group-based lookup).
|
||||
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
|
||||
if r.ID == "" && r.Type == "" {
|
||||
return nil
|
||||
}
|
||||
out := &proto.ResourceCompact{Type: string(r.Type)}
|
||||
if r.Type == types.ResourceTypePeer && r.ID != "" {
|
||||
if idx, ok := e.peerOrder[r.ID]; ok {
|
||||
out.PeerIndexSet = true
|
||||
out.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckSeqs translates a slice of posture-check xids to their
|
||||
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
|
||||
// lookup. Unresolvable xids are silently dropped — matches how group/peer
|
||||
// references handle the same case.
|
||||
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
|
||||
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(xids))
|
||||
for _, xid := range xids {
|
||||
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// networkSeq translates a Network xid to its per-account integer id using
|
||||
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
|
||||
// the xid isn't known — callers decide whether to skip the parent record.
|
||||
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
|
||||
if xid == "" {
|
||||
return 0, false
|
||||
}
|
||||
seq, ok := e.components.NetworkXIDToSeq[xid]
|
||||
if !ok || seq == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return seq, true
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
|
||||
if s == nil || len(s.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := &proto.DNSSettingsCompact{
|
||||
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
|
||||
}
|
||||
for _, gid := range s.DisabledManagementGroups {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
|
||||
if len(routes) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.RouteRaw, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
rr := &proto.RouteRaw{
|
||||
Id: r.AccountSeqID,
|
||||
NetId: string(r.NetID),
|
||||
Description: r.Description,
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: int32(r.NetworkType),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
SkipAutoApply: r.SkipAutoApply,
|
||||
Domains: r.Domains.ToPunycodeList(),
|
||||
GroupIds: e.groupIDsToSeq(r.Groups),
|
||||
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
}
|
||||
if r.Network.IsValid() {
|
||||
rr.NetworkCidr = r.Network.String()
|
||||
}
|
||||
if r.Peer != "" {
|
||||
if idx, ok := e.peerOrder[r.Peer]; ok {
|
||||
rr.PeerIndexSet = true
|
||||
rr.PeerIndex = idx
|
||||
}
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
if seq, ok := e.groupSeq(gid); ok {
|
||||
out = append(out, seq)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
|
||||
if len(nsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
|
||||
for _, nsg := range nsgs {
|
||||
if nsg == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NameServerGroupRaw{
|
||||
Id: nsg.AccountSeqID,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Nameservers: encodeNameServers(nsg.NameServers),
|
||||
GroupIds: e.groupIDsToSeq(nsg.Groups),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NameServer, 0, len(servers))
|
||||
for _, s := range servers {
|
||||
out = append(out, &proto.NameServer{
|
||||
IP: s.IP.String(),
|
||||
NSType: int64(s.NSType),
|
||||
Port: int64(s.Port),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, &proto.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int64(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int64(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
|
||||
if len(zones) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, &proto.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: encodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
|
||||
if len(resources) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
|
||||
for _, r := range resources {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkResourceRaw{
|
||||
Id: r.AccountSeqID,
|
||||
Name: r.Name,
|
||||
Description: r.Description,
|
||||
Type: string(r.Type),
|
||||
Address: r.Address,
|
||||
DomainValue: r.Domain,
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if seq, ok := e.networkSeq(r.NetworkID); ok {
|
||||
entry.NetworkSeq = seq
|
||||
}
|
||||
if r.Prefix.IsValid() {
|
||||
entry.PrefixCidr = r.Prefix.String()
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
|
||||
if len(routersMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
|
||||
for networkXID, routers := range routersMap {
|
||||
if len(routers) == 0 {
|
||||
continue
|
||||
}
|
||||
netSeq, ok := e.networkSeq(networkXID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
|
||||
for peerID, r := range routers {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
entry := &proto.NetworkRouterEntry{
|
||||
Id: r.AccountSeqID,
|
||||
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
|
||||
Masquerade: r.Masquerade,
|
||||
Metric: int32(r.Metric),
|
||||
Enabled: r.Enabled,
|
||||
}
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
entry.PeerIndexSet = true
|
||||
entry.PeerIndex = idx
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
|
||||
if len(rpm) == 0 {
|
||||
return nil
|
||||
}
|
||||
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
|
||||
// (small slice). Network resources without seq id are dropped, matching how
|
||||
// other components-without-seq are silently filtered.
|
||||
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
|
||||
for _, r := range e.components.NetworkResources {
|
||||
if r != nil && r.AccountSeqID != 0 {
|
||||
resourceXIDToSeq[r.ID] = r.AccountSeqID
|
||||
}
|
||||
}
|
||||
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
|
||||
for resourceXID, policies := range rpm {
|
||||
seq, ok := resourceXIDToSeq[resourceXID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(policies)*2)
|
||||
for _, pol := range policies {
|
||||
idxs = append(idxs, policyToIdxs[pol]...)
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.UserIDList, len(m))
|
||||
for groupID, userIDs := range m {
|
||||
seq, ok := e.groupSeq(groupID)
|
||||
if !ok || len(userIDs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.UserIDList{UserIds: userIDs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSetToSlice(s map[string]struct{}) []string {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(s))
|
||||
for k := range s {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[uint32]*proto.PeerIndexSet, len(m))
|
||||
for checkXID, failedPeerIDs := range m {
|
||||
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
|
||||
if !ok || seq == 0 {
|
||||
continue
|
||||
}
|
||||
idxs := make([]uint32, 0, len(failedPeerIDs))
|
||||
for peerID := range failedPeerIDs {
|
||||
if idx, ok := e.peerOrder[peerID]; ok {
|
||||
idxs = append(idxs, idx)
|
||||
}
|
||||
}
|
||||
if len(idxs) == 0 {
|
||||
continue
|
||||
}
|
||||
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// toAccountSettingsCompact always returns a non-nil message — the client
|
||||
// dereferences it unconditionally during Calculate(), so a nil here would
|
||||
// crash the receiver. A missing types.AccountSettingsInfo on the server
|
||||
// (which shouldn't happen in production but the encoder is exported)
|
||||
// degrades to login_expiration_enabled = false, which makes
|
||||
// LoginExpired() return false for every peer.
|
||||
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
|
||||
if s == nil {
|
||||
return &proto.AccountSettingsCompact{}
|
||||
}
|
||||
return &proto.AccountSettingsCompact{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
|
||||
}
|
||||
}
|
||||
|
||||
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
out := &proto.AccountNetwork{
|
||||
Identifier: n.Identifier,
|
||||
NetCidr: n.Net.String(),
|
||||
Dns: n.Dns,
|
||||
Serial: n.CurrentSerial(),
|
||||
}
|
||||
if len(n.NetV6.IP) > 0 {
|
||||
out.NetV6Cidr = n.NetV6.String()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
|
||||
pc := &proto.PeerCompact{
|
||||
WgPubKey: decodeWgKey(p.Key),
|
||||
SshPubKey: []byte(p.SSHKey),
|
||||
DnsLabel: p.DNSLabel,
|
||||
AgentVersionIdx: agentVersionIdx,
|
||||
AddedWithSsoLogin: p.UserID != "",
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
SshEnabled: p.SSHEnabled,
|
||||
SupportsIpv6: p.SupportsIPv6(),
|
||||
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
|
||||
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
|
||||
}
|
||||
if p.LastLogin != nil {
|
||||
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
|
||||
}
|
||||
switch {
|
||||
case !p.IP.IsValid():
|
||||
// leave Ip nil
|
||||
case p.IP.Is4() || p.IP.Is4In6():
|
||||
ip := p.IP.Unmap().As4()
|
||||
pc.Ip = ip[:]
|
||||
default:
|
||||
ip := p.IP.As16()
|
||||
pc.Ip = ip[:]
|
||||
}
|
||||
if p.IPv6.IsValid() {
|
||||
ip := p.IPv6.As16()
|
||||
pc.Ipv6 = ip[:]
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
|
||||
// key, or nil for an empty / malformed key.
|
||||
func decodeWgKey(s string) []byte {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, wgKeyRawLen)
|
||||
n, err := base64.StdEncoding.Decode(out, []byte(s))
|
||||
if err != nil || n != wgKeyRawLen {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portsToUint32(ports []string) []uint32 {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]uint32, 0, len(ports))
|
||||
for _, p := range ports {
|
||||
v, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, uint32(v))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.PortInfo_Range, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
out = append(out, &proto.PortInfo_Range{
|
||||
Start: uint32(r.Start),
|
||||
End: uint32(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
879
management/internals/shared/grpc/components_encoder_test.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
|
||||
|
||||
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
|
||||
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
|
||||
// to reference the new peer indexes. Groups, policies, and router indexes are
|
||||
// also sorted. After canonicalize, two envelopes built from the same logical
|
||||
// input compare byte-equal via proto.Equal.
|
||||
//
|
||||
// This lives on the test side — the encoder itself emits in map-iteration
|
||||
// order. Test-side normalization is the contract for "two encodes are
|
||||
// equivalent".
|
||||
func canonicalize(full *proto.NetworkMapComponentsFull) {
|
||||
if full == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Canonicalize agent_versions first: sort the slice and rewrite each
|
||||
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
|
||||
// index 0 by convention.
|
||||
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
|
||||
if len(full.AgentVersions) > 0 {
|
||||
// Pair version → original index, sort, rebuild.
|
||||
type avEntry struct {
|
||||
version string
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]avEntry, len(full.AgentVersions))
|
||||
for i, v := range full.AgentVersions {
|
||||
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
|
||||
}
|
||||
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
|
||||
// keeps the canonicalize output stable when two entries compare
|
||||
// equal (the encoder dedups, but defending against future inputs).
|
||||
slices.SortFunc(entries, func(a, b avEntry) int {
|
||||
if a.version == "" && b.version != "" {
|
||||
return -1
|
||||
}
|
||||
if b.version == "" && a.version != "" {
|
||||
return 1
|
||||
}
|
||||
if c := cmp.Compare(a.version, b.version); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.oldIdx, b.oldIdx)
|
||||
})
|
||||
newVersions := make([]string, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
avRemap[e.oldIdx] = uint32(newIdx)
|
||||
newVersions[newIdx] = e.version
|
||||
}
|
||||
full.AgentVersions = newVersions
|
||||
}
|
||||
for _, p := range full.Peers {
|
||||
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
|
||||
p.AgentVersionIdx = newIdx
|
||||
}
|
||||
}
|
||||
|
||||
type peerEntry struct {
|
||||
peer *proto.PeerCompact
|
||||
oldIdx uint32
|
||||
}
|
||||
entries := make([]peerEntry, len(full.Peers))
|
||||
for i, p := range full.Peers {
|
||||
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
|
||||
}
|
||||
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
|
||||
// nil from malformed keys, or both empty for placeholders).
|
||||
slices.SortFunc(entries, func(a, b peerEntry) int {
|
||||
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
|
||||
return c
|
||||
}
|
||||
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
|
||||
})
|
||||
|
||||
remap := make(map[uint32]uint32, len(entries))
|
||||
newPeers := make([]*proto.PeerCompact, len(entries))
|
||||
for newIdx, e := range entries {
|
||||
remap[e.oldIdx] = uint32(newIdx)
|
||||
newPeers[newIdx] = e.peer
|
||||
}
|
||||
full.Peers = newPeers
|
||||
|
||||
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
|
||||
for _, g := range full.Groups {
|
||||
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
|
||||
}
|
||||
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, r := range full.Routes {
|
||||
if r.PeerIndexSet {
|
||||
if newIdx, ok := remap[r.PeerIndex]; ok {
|
||||
r.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(r.GroupIds)
|
||||
slices.Sort(r.AccessControlGroupIds)
|
||||
slices.Sort(r.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
|
||||
|
||||
for _, list := range full.RoutersMap {
|
||||
for _, entry := range list.Entries {
|
||||
if entry.PeerIndexSet {
|
||||
if newIdx, ok := remap[entry.PeerIndex]; ok {
|
||||
entry.PeerIndex = newIdx
|
||||
}
|
||||
}
|
||||
slices.Sort(entry.PeerGroupIds)
|
||||
}
|
||||
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
|
||||
}
|
||||
|
||||
for _, set := range full.PostureFailedPeers {
|
||||
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
|
||||
}
|
||||
|
||||
for _, p := range full.Policies {
|
||||
slices.Sort(p.SourceGroupIds)
|
||||
slices.Sort(p.DestinationGroupIds)
|
||||
}
|
||||
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
|
||||
// multiple PolicyCompact entries sharing the same Id (one per rule, when
|
||||
// a Policy has multiple rules) still get a deterministic order. After
|
||||
// sorting we remap indexes in ResourcePoliciesMap.
|
||||
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
|
||||
for i, p := range full.Policies {
|
||||
policyOldOrder[p] = uint32(i)
|
||||
}
|
||||
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
|
||||
if c := cmp.Compare(a.Id, b.Id); c != 0 {
|
||||
return c
|
||||
}
|
||||
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
|
||||
return c
|
||||
}
|
||||
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
|
||||
})
|
||||
policyRemap := make(map[uint32]uint32, len(full.Policies))
|
||||
for newIdx, p := range full.Policies {
|
||||
policyRemap[policyOldOrder[p]] = uint32(newIdx)
|
||||
}
|
||||
for _, idxs := range full.ResourcePoliciesMap {
|
||||
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
|
||||
}
|
||||
for _, list := range full.GroupIdToUserIds {
|
||||
slices.Sort(list.UserIds)
|
||||
}
|
||||
slices.Sort(full.AllowedUserIds)
|
||||
}
|
||||
|
||||
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
|
||||
out := make([]uint32, 0, len(idxs))
|
||||
for _, i := range idxs {
|
||||
if newIdx, ok := remap[i]; ok {
|
||||
out = append(out, newIdx)
|
||||
}
|
||||
}
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
|
||||
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
|
||||
// the encoder is intentionally non-deterministic.
|
||||
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
|
||||
canonicalize(a.GetFull())
|
||||
canonicalize(b.GetFull())
|
||||
return goproto.Equal(a, b)
|
||||
}
|
||||
|
||||
func newTestComponents() *types.NetworkMapComponents {
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-a",
|
||||
Key: testWgKeyA,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peera",
|
||||
SSHKey: "ssh-a",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-b",
|
||||
Key: testWgKeyB,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
|
||||
DNSLabel: "peerb",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
|
||||
}
|
||||
peerC := &nbpeer.Peer{
|
||||
ID: "peer-c",
|
||||
Key: testWgKeyC,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
return &types.NetworkMapComponents{
|
||||
PeerID: "peer-a",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 7,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: 2 * time.Hour,
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-a": peerA,
|
||||
"peer-b": peerB,
|
||||
"peer-c": peerC,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
|
||||
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{
|
||||
ID: "pol-1",
|
||||
AccountSeqID: 10,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
|
||||
Ports: []string{"22", "80"},
|
||||
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full, "envelope must contain Full payload")
|
||||
|
||||
assert.EqualValues(t, 7, full.Serial)
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
|
||||
require.NotNil(t, full.Network)
|
||||
assert.Equal(t, "net-test", full.Network.Identifier)
|
||||
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
|
||||
|
||||
require.NotNil(t, full.AccountSettings)
|
||||
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
byLabel := map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
|
||||
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
// Hammer it 100 times — Go map iteration is randomized per call, so each
|
||||
// run produces different wire bytes, but the canonicalized form must
|
||||
// match.
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d must be semantically equivalent to first encode", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
results := make([]*proto.NetworkMapEnvelope, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, got := range results {
|
||||
require.NotNil(t, got, "goroutine %d returned nil", i)
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"goroutine %d produced inequivalent envelope", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 2)
|
||||
|
||||
groupByID := map[uint32]*proto.GroupCompact{}
|
||||
for _, g := range full.Groups {
|
||||
groupByID[g.Id] = g
|
||||
}
|
||||
require.Contains(t, groupByID, uint32(1))
|
||||
require.Contains(t, groupByID, uint32(2))
|
||||
assert.Equal(t, "Src", groupByID[1].Name)
|
||||
assert.Equal(t, "Dst", groupByID[2].Name)
|
||||
assert.Len(t, groupByID[1].PeerIndexes, 1)
|
||||
assert.Len(t, groupByID[2].PeerIndexes, 2)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.EqualValues(t, 10, pc.Id)
|
||||
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
|
||||
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
|
||||
assert.True(t, pc.Bidirectional)
|
||||
assert.Equal(t, []uint32{22, 80}, pc.Ports)
|
||||
require.Len(t, pc.PortRanges, 1)
|
||||
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
|
||||
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
|
||||
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.RouterPeerIndexes, 1)
|
||||
idx := full.RouterPeerIndexes[0]
|
||||
require.Less(t, int(idx), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
|
||||
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
|
||||
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
|
||||
"two distinct versions, order depends on map iteration")
|
||||
|
||||
idxByLabel := map[string]uint32{}
|
||||
for _, p := range full.Peers {
|
||||
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
|
||||
}
|
||||
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
|
||||
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Policies[0].Enabled = false
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Groups["group-src"].AccountSeqID = 0
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
|
||||
assert.EqualValues(t, 2, full.Groups[0].Id)
|
||||
|
||||
require.Len(t, full.Policies, 1)
|
||||
pc := full.Policies[0]
|
||||
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
|
||||
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
|
||||
// Both peers have nil WgPubKey after decode; canonicalize must still
|
||||
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
|
||||
// canonicalize identically.
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "garbage-a-!!!"
|
||||
c.Peers["peer-b"].Key = "garbage-b-!!!"
|
||||
|
||||
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
for i := 0; i < 100; i++ {
|
||||
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
require.True(t, envelopesEquivalent(expected, got),
|
||||
"encode #%d with two same-key peers must canonicalize equivalently", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-a"].Key = "not-base64-!!!"
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Peers, 3)
|
||||
|
||||
var byLabel = map[string]*proto.PeerCompact{}
|
||||
for _, p := range full.Peers {
|
||||
byLabel[p.DnsLabel] = p
|
||||
}
|
||||
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
|
||||
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
v6Only := &nbpeer.Peer{
|
||||
ID: "peer-v6",
|
||||
Key: testWgKeyA,
|
||||
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
|
||||
DNSLabel: "peerv6",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.Peers["peer-v6"] = v6Only
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerv6" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found, "ipv6-only peer must be present")
|
||||
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
|
||||
assert.Len(t, found.Ipv6, 16)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Peers["peer-noip"] = &nbpeer.Peer{
|
||||
ID: "peer-noip",
|
||||
Key: testWgKeyA,
|
||||
DNSLabel: "peernoip",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var found *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peernoip" {
|
||||
found = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, found)
|
||||
assert.Empty(t, found.Ip)
|
||||
assert.Empty(t, found.Ipv6)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
}
|
||||
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
|
||||
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Groups)
|
||||
assert.Empty(t, full.Policies)
|
||||
assert.Empty(t, full.RouterPeerIndexes)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
c.Peers["peer-a"].UserID = "user-1"
|
||||
c.Peers["peer-a"].LoginExpirationEnabled = true
|
||||
c.Peers["peer-a"].LastLogin = &now
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
var pa *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peera" {
|
||||
pa = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pa)
|
||||
assert.True(t, pa.AddedWithSsoLogin)
|
||||
assert.True(t, pa.LoginExpirationEnabled)
|
||||
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
|
||||
|
||||
// peer-b has no UserID and no LastLogin → all fields zero-value.
|
||||
var pb *proto.PeerCompact
|
||||
for _, p := range full.Peers {
|
||||
if p.DnsLabel == "peerb" {
|
||||
pb = p
|
||||
}
|
||||
}
|
||||
require.NotNil(t, pb)
|
||||
assert.False(t, pb.AddedWithSsoLogin)
|
||||
assert.False(t, pb.LoginExpirationEnabled)
|
||||
assert.Zero(t, pb.LastLoginUnixNano)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{
|
||||
{
|
||||
ID: "route-peer",
|
||||
AccountSeqID: 100,
|
||||
NetID: "net-A",
|
||||
Description: "via peer-c",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Peer: "peer-c", // peer ID, not WG key
|
||||
Groups: []string{"group-src"},
|
||||
AccessControlGroups: []string{"group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-peergroup",
|
||||
AccountSeqID: 101,
|
||||
NetID: "net-B",
|
||||
Network: netip.MustParsePrefix("10.1.0.0/16"),
|
||||
PeerGroups: []string{"group-src", "group-dst"},
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "route-no-seq",
|
||||
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
|
||||
Network: netip.MustParsePrefix("10.2.0.0/16"),
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 3)
|
||||
byNetID := map[string]*proto.RouteRaw{}
|
||||
for _, r := range full.Routes {
|
||||
byNetID[r.NetId] = r
|
||||
}
|
||||
|
||||
r1 := byNetID["net-A"]
|
||||
require.NotNil(t, r1)
|
||||
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
|
||||
require.Less(t, int(r1.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
|
||||
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
|
||||
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
|
||||
assert.Empty(t, r1.PeerGroupIds)
|
||||
|
||||
r2 := byNetID["net-B"]
|
||||
require.NotNil(t, r2)
|
||||
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
|
||||
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.Routes = []*nbroute.Route{{
|
||||
ID: "route-x",
|
||||
AccountSeqID: 100,
|
||||
Peer: "peer-not-in-components",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/16"),
|
||||
Enabled: true,
|
||||
}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Routes, 1)
|
||||
assert.False(t, full.Routes[0].PeerIndexSet,
|
||||
"missing peer reference must not pretend to point at peer index 0")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
|
||||
// is the I1 case — without unionPolicies the encoder would silently
|
||||
// drop it from the wire.
|
||||
resourceOnlyPolicy := &types.Policy{
|
||||
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
Sources: []string{"group-src"},
|
||||
Destinations: []string{"group-dst"},
|
||||
}},
|
||||
}
|
||||
c.ResourcePoliciesMap = map[string][]*types.Policy{
|
||||
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
|
||||
}
|
||||
// Resource must appear in components.NetworkResources with a seq id —
|
||||
// encoder uses that to translate the xid map key to uint32.
|
||||
c.NetworkResources = []*resourceTypes.NetworkResource{
|
||||
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
|
||||
|
||||
policyByID := map[uint32]*proto.PolicyCompact{}
|
||||
policyIdxByID := map[uint32]uint32{}
|
||||
for i, p := range full.Policies {
|
||||
policyByID[p.Id] = p
|
||||
policyIdxByID[p.Id] = uint32(i)
|
||||
}
|
||||
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
|
||||
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
|
||||
|
||||
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
|
||||
idxs := full.ResourcePoliciesMap[77].Indexes
|
||||
require.Len(t, idxs, 2)
|
||||
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
|
||||
"resource policies map must reference both wire policy indexes")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NameServerGroups = []*nbdns.NameServerGroup{{
|
||||
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
|
||||
}},
|
||||
Groups: []string{"group-src", "group-not-persisted"},
|
||||
Primary: true, Enabled: true,
|
||||
Domains: []string{"corp.example"},
|
||||
}}
|
||||
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.NameserverGroups, 1)
|
||||
nsg := full.NameserverGroups[0]
|
||||
assert.EqualValues(t, 50, nsg.Id)
|
||||
assert.Equal(t, "Main", nsg.Name)
|
||||
assert.True(t, nsg.Primary)
|
||||
require.Len(t, nsg.Nameservers, 1)
|
||||
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
|
||||
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
|
||||
c.PostureFailedPeers = map[string]map[string]struct{}{
|
||||
"check-1": {
|
||||
"peer-a": {},
|
||||
"peer-b": {},
|
||||
"peer-not-in-account": {},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.PostureFailedPeers, uint32(33))
|
||||
idxs := full.PostureFailedPeers[33].PeerIndexes
|
||||
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {
|
||||
"peer-c": {
|
||||
ID: "router-1", AccountSeqID: 200,
|
||||
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
entries := full.RoutersMap[5].Entries
|
||||
require.Len(t, entries, 1)
|
||||
e := entries[0]
|
||||
assert.EqualValues(t, 200, e.Id)
|
||||
assert.True(t, e.PeerIndexSet)
|
||||
require.Less(t, int(e.PeerIndex), len(full.Peers))
|
||||
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
|
||||
assert.True(t, e.Masquerade)
|
||||
assert.EqualValues(t, 10, e.Metric)
|
||||
assert.True(t, e.Enabled)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
|
||||
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
|
||||
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
|
||||
// peer_index reference must still resolve.
|
||||
c := newTestComponents()
|
||||
delete(c.Peers, "peer-c")
|
||||
routerPeer := &nbpeer.Peer{
|
||||
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
|
||||
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
|
||||
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
|
||||
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
|
||||
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Contains(t, full.RoutersMap, uint32(5))
|
||||
require.Len(t, full.RoutersMap[5].Entries, 1)
|
||||
e := full.RoutersMap[5].Entries[0]
|
||||
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.DnsSettings)
|
||||
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
|
||||
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
c.GroupIDToUserIDs = map[string][]string{
|
||||
"group-src": {"user-1", "user-2"},
|
||||
"group-no-seq": {"user-3"}, // group not persisted → drop
|
||||
"group-missing": {"user-4"}, // group not in components → drop
|
||||
}
|
||||
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
|
||||
require.Contains(t, full.GroupIdToUserIds, uint32(1))
|
||||
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
|
||||
}
|
||||
|
||||
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
|
||||
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
|
||||
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
|
||||
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
|
||||
}
|
||||
|
||||
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
|
||||
nm := &types.NetworkMap{
|
||||
Peers: []*nbpeer.Peer{{
|
||||
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
|
||||
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}},
|
||||
FirewallRules: []*types.FirewallRule{{
|
||||
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
|
||||
}},
|
||||
}
|
||||
|
||||
patch := toProxyPatch(nm, "netbird.cloud", false, false)
|
||||
|
||||
require.NotNil(t, patch)
|
||||
assert.Len(t, patch.Peers, 1)
|
||||
assert.Len(t, patch.FirewallRules, 1)
|
||||
}
|
||||
|
||||
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
|
||||
// pass-through in both encoder branches (normal path + nil-Components
|
||||
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
|
||||
// from one of the envelope struct literals.
|
||||
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
|
||||
patch := &proto.ProxyPatch{
|
||||
ForwardingRules: []*proto.ForwardingRule{{
|
||||
Protocol: proto.RuleProtocol_TCP,
|
||||
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
|
||||
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
|
||||
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
|
||||
}},
|
||||
}
|
||||
|
||||
t.Run("normal_path", func(t *testing.T) {
|
||||
c := newTestComponents()
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
|
||||
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
ProxyPatch: patch,
|
||||
}).GetFull()
|
||||
|
||||
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
|
||||
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
|
||||
// nil Components → minimal envelope, no crash. Matches the legacy
|
||||
// behaviour for missing/unvalidated peers.
|
||||
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: nil,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
require.NotNil(t, env)
|
||||
full := env.GetFull()
|
||||
require.NotNil(t, full)
|
||||
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
|
||||
assert.Equal(t, "netbird.cloud", full.DnsDomain)
|
||||
assert.Empty(t, full.Peers)
|
||||
assert.Empty(t, full.Policies)
|
||||
}
|
||||
|
||||
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
|
||||
c := &types.NetworkMapComponents{
|
||||
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
|
||||
// AccountSettings deliberately nil
|
||||
}
|
||||
|
||||
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
|
||||
|
||||
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
|
||||
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
|
||||
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
|
||||
}
|
||||
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
192
management/internals/shared/grpc/components_envelope_response.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// ToComponentSyncResponse builds a SyncResponse carrying the compact
|
||||
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
|
||||
// field is intentionally left empty — capable peers ignore it and the
|
||||
// envelope alone is the authoritative wire shape.
|
||||
//
|
||||
// PeerConfig is computed once server-side using the receiving peer's own
|
||||
// account-level network metadata. EnableSSH inside PeerConfig is left at
|
||||
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
|
||||
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
|
||||
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
|
||||
// client even though the server-side PeerConfig reports false.
|
||||
func ToComponentSyncResponse(
|
||||
ctx context.Context,
|
||||
config *nbconfig.Config,
|
||||
httpConfig *nbconfig.HttpServerConfig,
|
||||
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
|
||||
peer *nbpeer.Peer,
|
||||
turnCredentials *Token,
|
||||
relayCredentials *Token,
|
||||
components *types.NetworkMapComponents,
|
||||
proxyPatch *types.NetworkMap,
|
||||
dnsName string,
|
||||
checks []*posture.Checks,
|
||||
settings *types.Settings,
|
||||
extraSettings *types.ExtraSettings,
|
||||
peerGroups []string,
|
||||
dnsFwdPort int64,
|
||||
) *proto.SyncResponse {
|
||||
network := networkOrZero(components)
|
||||
enableSSH := computeSSHEnabledForPeer(components, peer)
|
||||
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
|
||||
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: peerConfig,
|
||||
DNSDomain: dnsName,
|
||||
DNSForwarderPort: dnsFwdPort,
|
||||
UserIDClaim: userIDClaim,
|
||||
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
|
||||
})
|
||||
|
||||
resp := &proto.SyncResponse{
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMapEnvelope: envelope,
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// networkOrZero returns components.Network or a zero Network — toPeerConfig
|
||||
// dereferences network.Net which would panic on nil.
|
||||
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
|
||||
if c == nil || c.Network == nil {
|
||||
return &types.Network{}
|
||||
}
|
||||
return c.Network
|
||||
}
|
||||
|
||||
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
|
||||
// patch the components envelope ships alongside. Returns nil when there are
|
||||
// no fragments to merge — proto3 omits a nil message field, so the receiver
|
||||
// sees no patch and skips the merge step entirely.
|
||||
//
|
||||
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
|
||||
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
|
||||
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
|
||||
// delivers fragments pre-expanded — there's no raw component shape to
|
||||
// derive them from. Components purity isn't violated: proxy data isn't
|
||||
// policy-graph-derived, it's externally injected post-Calculate, so the
|
||||
// client merges it on top of its locally-computed NetworkMap.
|
||||
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
|
||||
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
patch := &proto.ProxyPatch{
|
||||
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
|
||||
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
|
||||
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
|
||||
Routes: networkmap.ToProtocolRoutes(nm.Routes),
|
||||
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
|
||||
}
|
||||
if len(nm.ForwardingRules) > 0 {
|
||||
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
|
||||
for _, r := range nm.ForwardingRules {
|
||||
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
|
||||
}
|
||||
}
|
||||
return patch
|
||||
}
|
||||
|
||||
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
|
||||
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
|
||||
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
|
||||
// without this helper the field would be incorrectly false for any peer
|
||||
// that's the destination of an SSH-enabling policy without having
|
||||
// peer.SSHEnabled set locally.
|
||||
//
|
||||
// Mirrors the two activation paths Calculate() uses:
|
||||
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
|
||||
// destinations.
|
||||
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
|
||||
// destinations, AND the peer has SSHEnabled set locally — this is the
|
||||
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
|
||||
//
|
||||
// The full SSH AuthorizedUsers map is still produced by the client when it
|
||||
// runs Calculate() over the envelope.
|
||||
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
|
||||
if c == nil || peer == nil {
|
||||
return false
|
||||
}
|
||||
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
|
||||
// exist in c.Peers, otherwise no rule applies to it.
|
||||
if _, ok := c.Peers[peer.ID]; !ok {
|
||||
return false
|
||||
}
|
||||
for _, policy := range c.Policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if ruleEnablesSSHForPeer(c, rule, peer) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
|
||||
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
|
||||
// peer itself has SSH enabled locally.
|
||||
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
|
||||
if rule == nil || !rule.Enabled {
|
||||
return false
|
||||
}
|
||||
if !peerInDestinations(c, rule, peer.ID) {
|
||||
return false
|
||||
}
|
||||
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
|
||||
return true
|
||||
}
|
||||
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
// peerInDestinations reports whether peerID is in any of rule.Destinations'
|
||||
// groups (or matches DestinationResource if it's a peer-typed resource —
|
||||
// for non-peer types Calculate falls through to group lookup, so we mirror
|
||||
// that exactly to avoid silent divergence).
|
||||
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
return rule.DestinationResource.ID == peerID
|
||||
}
|
||||
for _, groupID := range rule.Destinations {
|
||||
if c.IsPeerInGroup(peerID, groupID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,184 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
|
||||
// explicit NetbirdSSH protocol, and the legacy implicit case where a
|
||||
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
|
||||
// the destination peer has SSHEnabled=true locally.
|
||||
func TestComputeSSHEnabledForPeer(t *testing.T) {
|
||||
const targetPeerID = "target"
|
||||
const targetGroupID = "g_dst"
|
||||
|
||||
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
|
||||
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
|
||||
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
|
||||
return &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
|
||||
Groups: map[string]*types.Group{targetGroupID: group},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p",
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
}},
|
||||
}, peer
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
peerSSH bool
|
||||
rule types.PolicyRule
|
||||
wantEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22-without-peer-ssh-disabled",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "implicit-tcp-22022-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-all-protocol-with-peer-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "implicit-port-range-covers-22",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolTCP,
|
||||
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "tcp-80-no-ssh",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "disabled-rule-skipped",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{targetGroupID},
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-not-in-destinations",
|
||||
peerSSH: true,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g_other"}, // target not in this group
|
||||
},
|
||||
wantEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "peer-typed-destination-resource-matches",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "non-peer-destination-resource-falls-through-to-groups",
|
||||
peerSSH: false,
|
||||
rule: types.PolicyRule{
|
||||
Enabled: true,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
|
||||
Destinations: []string{targetGroupID}, // saved by group fallback
|
||||
},
|
||||
wantEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, peer := mkComponents(&tc.rule, tc.peerSSH)
|
||||
got := computeSSHEnabledForPeer(c, peer)
|
||||
assert.Equal(t, tc.wantEnabled, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
|
||||
// belt-and-suspenders presence guard mirroring Calculate's
|
||||
// getAllPeersFromGroups invariant.
|
||||
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
|
||||
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
|
||||
c := &types.NetworkMapComponents{
|
||||
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
|
||||
Groups: map[string]*types.Group{
|
||||
"g": {ID: "g", Peers: []string{"missing"}},
|
||||
},
|
||||
Policies: []*types.Policy{{
|
||||
ID: "p", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Destinations: []string{"g"},
|
||||
}},
|
||||
}},
|
||||
}
|
||||
assert.False(t, computeSSHEnabledForPeer(c, peer),
|
||||
"missing target peer must short-circuit to false, not consult policies")
|
||||
}
|
||||
|
||||
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
|
||||
// function entry — Calculate doesn't accept nil either, but the helper is
|
||||
// exported indirectly via ToComponentSyncResponse and may receive nil
|
||||
// components on graceful-degrade paths.
|
||||
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
|
||||
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
|
||||
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
|
||||
}
|
||||
@@ -10,24 +10,20 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -159,8 +155,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
@@ -173,7 +169,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
|
||||
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
|
||||
response.RemotePeers = remotePeers
|
||||
@@ -183,13 +179,13 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
|
||||
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
|
||||
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
@@ -202,7 +198,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
@@ -242,33 +238,6 @@ func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||
return timestamppb.New(deadline)
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
if nbversion.IsDevelopmentVersion(peerVersion) {
|
||||
return true
|
||||
@@ -282,51 +251,6 @@ func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
switch configProto {
|
||||
case nbconfig.UDP:
|
||||
@@ -344,203 +268,6 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
|
||||
}
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
|
||||
// alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if shouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPv4Unspecified/0 is always valid, error is impossible.
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
// IPv6Unspecified/0 is always valid, error is impossible.
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if shouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// getProtoDirection converts the direction to proto.RuleDirection.
|
||||
func getProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: getProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
PortInfo: getProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// getProtoAction converts the action to proto.RuleAction.
|
||||
func getProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// getProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func getProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// getProtoPortInfo converts the port info to proto.PortInfo.
|
||||
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
func shouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
|
||||
if config == nil || config.AuthAudience == "" {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -62,13 +63,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -100,7 +101,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -108,7 +109,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1016,7 +1016,31 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
dnsName := s.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
var plainResp *proto.SyncResponse
|
||||
if s.networkMapController.PeerNeedsComponents(peer) {
|
||||
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
|
||||
// computed and recompute the raw components instead. This wastes one
|
||||
// Calculate() call per initial-sync — the component-based wire
|
||||
// format is what the peer actually consumes. The streaming path
|
||||
// (network_map.Controller.UpdateAccountPeers) skips this duplication
|
||||
// because it dispatches by capability before computing.
|
||||
//
|
||||
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
|
||||
// interfaces to return PeerNetworkMapResult so the initial-sync path
|
||||
// stops doing duplicate work. Deferred until the client-side
|
||||
// decoder lands and there's a real deployment of capability=3 peers
|
||||
// worth optimizing for.
|
||||
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
|
||||
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
|
||||
}
|
||||
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
} else {
|
||||
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
}
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -1636,6 +1636,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range newGroupsToCreate {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
|
||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -3170,6 +3170,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
|
||||
var newJWTGroup *types.Group
|
||||
for _, g := range groups {
|
||||
if g.Name == "group3" {
|
||||
newJWTGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -106,11 +106,9 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
|
||||
for _, peerID := range mustExclude {
|
||||
assert.NotContains(t, affected, peerID, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,7 +251,9 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
// A disabled sibling router routes to nobody, so updating a resource on its network
|
||||
// must NOT refresh its peer (the enabled router carries the bridge instead).
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -274,13 +276,18 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
settleAffectedUpdates(disabledCh, enabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
peerShouldReceiveUpdate(t, enabledCh)
|
||||
peerShouldNotReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -298,7 +305,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
t.Error("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -682,6 +682,9 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
// A disabled router in the snapshot routes to nobody, so it is skipped when the
|
||||
// walk scans existing account data: a policy edit still folds the literal source
|
||||
// group, but not the disabled router's peer.
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -694,11 +697,13 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
|
||||
}
|
||||
|
||||
// A disabled resource in the snapshot is skipped: the policy edit still folds the
|
||||
// literal source group, but the resource no longer bridges to its network's router.
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
@@ -710,9 +715,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
|
||||
assert.NotContains(t, affected, s.routerPeerID,
|
||||
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
|
||||
@@ -96,33 +96,54 @@ func affectedGroupID(i int) string { return fmt.Sprintf("affected-grp-%d", i)
|
||||
func affectedGroupName(i int) string { return fmt.Sprintf("AffectedGroup%d", i) }
|
||||
|
||||
func TestCollectGroupChange_PolicyLinked(t *testing.T) {
|
||||
manager, s, accountID, _, groupIDs := setupAffectedPeersTest(t)
|
||||
manager, s, accountID, peerIDs, groupIDs := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := manager.SavePolicy(ctx, accountID, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[0], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypeHost},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, _ := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[1]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[0]})
|
||||
|
||||
groups, _ = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
@@ -133,20 +154,44 @@ func TestCollectGroupChange_PolicyWithDirectPeerResource(t *testing.T) {
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[3], Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: peerIDs[4], Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: peerIDs[1], Type: types.ResourceTypeHost},
|
||||
DestinationResource: types.Resource{ID: peerIDs[2], Type: types.ResourceTypeHost},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{groupIDs[0]},
|
||||
SourceResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
DestinationResource: types.Resource{ID: "", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{groupIDs[1]},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.Contains(t, directPeers, peerIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[4]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[1]})
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.ElementsMatch(t, directPeers, []string{peerIDs[3]})
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[2]})
|
||||
assert.Empty(t, groups)
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T) {
|
||||
@@ -168,8 +213,7 @@ func TestCollectGroupChange_PolicyWithNonPeerResource_NoDirectPeers(t *testing.T
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers, "non-peer resources should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
@@ -294,6 +338,7 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -324,6 +369,7 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
Peer: peerIDs[4],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -373,17 +419,11 @@ func TestCollectGroupChange_MultipleEntities(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, directPeers := collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[0]})
|
||||
assert.Contains(t, groups, groupIDs[0])
|
||||
assert.Contains(t, groups, groupIDs[1])
|
||||
assert.NotContains(t, groups, groupIDs[2])
|
||||
assert.NotContains(t, groups, groupIDs[3])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[0], groupIDs[1]})
|
||||
assert.Empty(t, directPeers)
|
||||
|
||||
groups, directPeers = collectGroupChangeAffectedGroups(ctx, s, accountID, []string{groupIDs[3]})
|
||||
assert.Contains(t, groups, groupIDs[2])
|
||||
assert.Contains(t, groups, groupIDs[3])
|
||||
assert.NotContains(t, groups, groupIDs[0])
|
||||
assert.NotContains(t, groups, groupIDs[1])
|
||||
assert.ElementsMatch(t, groups, []string{groupIDs[2], groupIDs[3]})
|
||||
assert.Empty(t, directPeers)
|
||||
}
|
||||
|
||||
@@ -452,8 +492,9 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
@@ -474,7 +515,7 @@ func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
@@ -506,8 +547,9 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
|
||||
|
||||
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
|
||||
@@ -564,9 +606,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
|
||||
|
||||
// peer3 is unrelated
|
||||
// peer3 is unrelated to the route; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
@@ -587,6 +629,7 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupIDs[0]},
|
||||
Peer: peerIDs[3],
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -659,9 +702,13 @@ func TestResolveAffectedPeers_PeerInMultipleGroups(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// peer0 is in group0 AND group1, so both policies apply
|
||||
// peer0 is in group0 AND group1, so both policies apply. A peer change folds
|
||||
// only the changed peer plus the opposite side of each rule: group2 (peer2) via
|
||||
// the group0 policy and group3 (peer3) via the group1 policy. peer1, a co-member
|
||||
// of group1, is a sibling of the changed peer and must NOT refresh.
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.NotContains(t, result, peerIDs[1], "co-member of the changed peer's group must not refresh")
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
@@ -697,7 +744,7 @@ func TestResolveAffectedPeers_MultipleChangedPeers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0], peerIDs[2]})
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2], peerIDs[3]}, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[2], peerIDs[1], peerIDs[3]}, result)
|
||||
}
|
||||
|
||||
func TestResolveAffectedPeers_SharedGroupAcrossPolicyAndRoute(t *testing.T) {
|
||||
@@ -854,8 +901,9 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
|
||||
assert.NotContains(t, result, peerIDs[0])
|
||||
assert.NotContains(t, result, peerIDs[1])
|
||||
|
||||
// peerIDs[4] is in neither isolated policy; only its own map can change.
|
||||
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
|
||||
@@ -977,12 +1025,13 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
|
||||
// A peer in no policy/route refreshes only itself — no other peer is affected.
|
||||
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
|
||||
assert.Empty(t, result)
|
||||
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
|
||||
@@ -1332,6 +1381,7 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
|
||||
NetworkID: net1.ID,
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{"nr-grpA"},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1755,7 +1805,9 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
|
||||
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
|
||||
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
|
||||
// target peer does not pull in the service's proxy peer.
|
||||
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
|
||||
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1781,8 +1833,7 @@ func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing
|
||||
require.NoError(t, s.CreateService(ctx, svc))
|
||||
|
||||
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
|
||||
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
|
||||
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
|
||||
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
|
||||
}
|
||||
|
||||
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {
|
||||
|
||||
@@ -6,7 +6,12 @@
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
// Enabled handling differs by source. Disabled objects in the SNAPSHOT (existing
|
||||
// account policies/resources/routers/routes/proxy services and their rules/targets)
|
||||
// route to nobody and are skipped — they cannot affect any peer's map. Objects in
|
||||
// the CHANGE itself are processed regardless of Enabled, so disabling one still
|
||||
// refreshes the peers that lose access (the toggle is the observable change, and the
|
||||
// update carries the old∪new state).
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
@@ -61,7 +66,8 @@ func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snap
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
// LinkGroups drive the same policy/route/dns walk as a changed group or peer.
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
@@ -76,7 +82,7 @@ func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accoun
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.LinkGroups) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -174,6 +180,24 @@ type Change struct {
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
|
||||
// OutputPeerIDs are peers folded straight into the result without seeding their
|
||||
// group memberships into the walk. Use for the peer whose group membership changed:
|
||||
// the peer itself must refresh, but its OTHER groups did not change, so they must
|
||||
// not be walked. Contrast ChangedPeerIDs, which seeds ALL of the peer's groups
|
||||
// (correct when the peer's own attributes changed, e.g. IP/status).
|
||||
OutputPeerIDs []string
|
||||
|
||||
// LinkGroups are groups used ONLY to match policies/routes/routers and walk to the
|
||||
// OPPOSITE side — they are never expanded to their own members. Use this when a
|
||||
// peer's group membership changed: pass the peer in ChangedPeerIDs and its
|
||||
// group(s) here. The opposite side of the policies the group participates in
|
||||
// refreshes, but the group's other members (siblings) do not — nothing changed for
|
||||
// them. For an intra-group policy (A→A) the opposite side IS the group, so its
|
||||
// members still refresh via the opposite-side fold, exactly when they genuinely
|
||||
// gain/lose the changed peer. Unlike ChangedGroupIDs, a LinkGroup is not added to
|
||||
// the output, so a one-sided membership change never wakes the whole group.
|
||||
LinkGroups []string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
@@ -186,7 +210,9 @@ func (c Change) isEmpty() bool {
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
len(c.RemovedPeersByGroup) == 0 &&
|
||||
len(c.LinkGroups) == 0 &&
|
||||
len(c.OutputPeerIDs) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
@@ -197,8 +223,8 @@ func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v linkGroups=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, c.LinkGroups, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
@@ -216,57 +242,84 @@ func Collect(ctx context.Context, s store.Store, accountID string, c Change) (gr
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
return setToSlice(r.affectedGroups), setToSlice(r.affectedPeers)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
linkGroups: toSet(c.ChangedGroupIDs),
|
||||
outputGroups: toSet(c.ChangedGroupIDs),
|
||||
changedPeers: toSet(c.ChangedPeerIDs),
|
||||
affectedGroups: make(map[string]struct{}),
|
||||
affectedPeers: make(map[string]struct{}),
|
||||
}
|
||||
// LinkGroups match policies/routes to find the opposite side but are NOT output:
|
||||
// they go into linkGroups only, never outputGroups, so their members never fold in.
|
||||
addAll(r.linkGroups, c.LinkGroups)
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to linkGroups so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
// These seeded groups are for MATCHING only — folding the changed entity's own
|
||||
// side is gated on outputGroups (the caller-reported groups), so a seeded group
|
||||
// never folds its whole membership; only the changed peer itself folds in.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
if len(r.changedPeers) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
for pID := range r.changedPeers {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
r.linkGroups[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policySide selects which side of a policy rule to walk.
|
||||
type policySide int
|
||||
|
||||
const (
|
||||
sideSource policySide = iota
|
||||
sideDestination
|
||||
)
|
||||
|
||||
func (s policySide) opposite() policySide {
|
||||
if s == sideSource {
|
||||
return sideDestination
|
||||
}
|
||||
return sideSource
|
||||
}
|
||||
|
||||
// walk resolves affected peers in two buckets, by how far each change propagates.
|
||||
//
|
||||
// BOTH-SIDES — the rule itself changed (an explicit policy edit, or a policy whose
|
||||
// posture check changed). Source AND destination refresh, so each such policy is
|
||||
// walked on both sides.
|
||||
//
|
||||
// OPPOSITE-SIDE — an endpoint moved but no rule changed. For each policy the change
|
||||
// touches we fold only the side AWAY from the change:
|
||||
// - a changed peer/group sits ON a policy side -> fold the opposite side;
|
||||
// - a changed router/resource/network sits on a NETWORK -> fold the SOURCE side of
|
||||
// the policies whose destination reaches it (and the routers it implies).
|
||||
//
|
||||
// Routes, nameserver groups, DNS and embedded-proxy services distribute to their own
|
||||
// member peers, outside the policy graph, and are folded here too.
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
for _, policy := range r.bothSidesPolicies() {
|
||||
r.foldPolicySide(policy, sideSource)
|
||||
r.foldPolicySide(policy, sideDestination)
|
||||
}
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
if len(r.linkGroups) > 0 || len(r.changedPeers) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
@@ -275,7 +328,31 @@ func (r *resolver) walk() {
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
r.collectFromChangedRoutes(r.change.Routes)
|
||||
r.collectFromChangedRouters(r.change.Routers)
|
||||
r.collectFromChangedResources(r.change.Resources)
|
||||
r.collectFromChangedNetworks(r.change.Networks)
|
||||
|
||||
// The explicitly changed peers always refresh their own maps. OnPeersUpdated only
|
||||
// refreshes the resolver's output (it ignores the separately-passed changed peers),
|
||||
// so the changed peer reaches its own new map only via here. An offline/deleted
|
||||
// peer in the set is filtered downstream (filterConnectedAffectedPeers).
|
||||
addAll(r.affectedPeers, setToSlice(r.changedPeers))
|
||||
// OutputPeerIDs refresh themselves too, but unlike changedPeers their group
|
||||
// memberships were not seeded into the walk (only the changed group was).
|
||||
addAll(r.affectedPeers, r.change.OutputPeerIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into affectedGroups so expand() maps them to members, without the
|
||||
// policy/route walk that linkGroups would trigger.
|
||||
addAll(r.affectedGroups, r.change.DistributionGroupIDs)
|
||||
}
|
||||
|
||||
// bothSidesPolicies are the policies whose rule changed: the explicitly edited ones
|
||||
// plus those gated by a changed posture check. walk folds both their sides.
|
||||
func (r *resolver) bothSidesPolicies() []*types.Policy {
|
||||
policies := append([]*types.Policy(nil), r.change.Policies...)
|
||||
return r.appendPoliciesForPostureChecks(policies, r.change.PostureCheckIDs)
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
@@ -284,27 +361,71 @@ type resolver struct {
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
// Inputs — what changed. Set once at construction, read-only during the walk
|
||||
// (except linkGroups, which collectFromExplicitResources also seeds).
|
||||
//
|
||||
// linkGroups is the MATCH set: caller-changed groups ∪ the groups of changed
|
||||
// peers ∪ changed-resource groups. A rule/route/router matches the change when
|
||||
// one of its groups is here — used only to find the opposite side to fold.
|
||||
//
|
||||
// outputGroups is the FOLD-WHOLE-GROUP set: ONLY Change.ChangedGroupIDs. When a
|
||||
// matched group is here, its whole membership is affected. A peer-seeded group
|
||||
// is in linkGroups but NOT outputGroups, so it folds only the changed peer
|
||||
// (changedPeers), never its siblings.
|
||||
linkGroups map[string]struct{}
|
||||
outputGroups map[string]struct{}
|
||||
changedPeers map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
// Outputs — the answer. The only sets the walk accumulates into. affectedGroups
|
||||
// is expanded to its member peers in expand().
|
||||
affectedGroups map[string]struct{}
|
||||
affectedPeers map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
// policies returns the account's ENABLED policies from the snapshot. Disabled
|
||||
// policies grant no access, so the walk skips them when scanning existing account
|
||||
// data. Explicitly changed policies (Change.Policies, via bothSidesPolicies) are
|
||||
// processed regardless of Enabled, so disabling one still refreshes its peers.
|
||||
func (r *resolver) policies() []*types.Policy {
|
||||
enabled := make([]*types.Policy, 0, len(r.snap.policies))
|
||||
for _, policy := range r.snap.policies {
|
||||
if policy != nil && policy.Enabled {
|
||||
enabled = append(enabled, policy)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
// networkResources / networkRouters return the account's ENABLED resources/routers
|
||||
// from the snapshot. Disabled objects route to nobody, so the walk skips them when
|
||||
// it scans existing account data. The explicitly changed objects in the Change are
|
||||
// processed regardless of Enabled (collectFromChanged*), so disabling one still
|
||||
// refreshes the peers that lose access.
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
|
||||
enabled := make([]*resourceTypes.NetworkResource, 0, len(r.snap.resources))
|
||||
for _, resource := range r.snap.resources {
|
||||
if resource.Enabled {
|
||||
enabled = append(enabled, resource)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
|
||||
enabled := make([]*routerTypes.NetworkRouter, 0, len(r.snap.routers))
|
||||
for _, router := range r.snap.routers {
|
||||
if router.Enabled {
|
||||
enabled = append(enabled, router)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
func (r *resolver) peerIDsForGroups(groups map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groupSet {
|
||||
for gID := range groups {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
@@ -317,25 +438,25 @@ func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
peerIDs := r.peerIDsForGroups(r.affectedGroups)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
r.accountID, setToSlice(r.affectedGroups), len(peerIDs), setToSlice(r.affectedPeers))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
for id := range r.affectedPeers {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
// Fold in removed peers only when their group is linked (in affectedGroups).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
if _, linked := r.affectedGroups[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
@@ -351,169 +472,349 @@ func (r *resolver) expand() []string {
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
// ruleSideGroups / ruleSideResource return the groups and the resource on the given
|
||||
// side of a rule.
|
||||
func ruleSideGroups(rule *types.PolicyRule, side policySide) []string {
|
||||
if side == sideDestination {
|
||||
return rule.Destinations
|
||||
}
|
||||
return rule.Sources
|
||||
}
|
||||
|
||||
func ruleSideResource(rule *types.PolicyRule, side policySide) types.Resource {
|
||||
if side == sideDestination {
|
||||
return rule.DestinationResource
|
||||
}
|
||||
return rule.SourceResource
|
||||
}
|
||||
|
||||
// foldPolicySide folds one side of a policy down to affected peers: its groups
|
||||
// (resolved to members in expand) and its direct peer. When the side is the
|
||||
// DESTINATION and references a network resource (directly or via a destination
|
||||
// group's resources), it also folds the routers that serve that resource's network
|
||||
// — a destination resource is reached through its routers. A resource on the SOURCE
|
||||
// side routes to nobody (GetPoliciesForNetworkResource matches destinations only),
|
||||
// so the router hop is destination-only.
|
||||
func (r *resolver) foldPolicySide(policy *types.Policy, side policySide) {
|
||||
if policy == nil {
|
||||
return
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.policyDestinationResourceIDs(policy))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
// appendPoliciesForPostureChecks appends every policy that references a changed
|
||||
// posture check (a rule change, so walk both sides).
|
||||
func (r *resolver) appendPoliciesForPostureChecks(policies []*types.Policy, postureCheckIDs []string) []*types.Policy {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return policies
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("appendPoliciesForPostureChecks: policy %s (%s) references changed posture checks %v -> both-sides policy",
|
||||
policy.ID, policy.Name, postureCheckIDs)
|
||||
policies = append(policies, policy)
|
||||
}
|
||||
return policies
|
||||
}
|
||||
|
||||
// collectFromPolicies folds, for every policy whose rule a changed group or peer
|
||||
// touches, only the OPPOSITE side (down to peers, incl. destination routers), plus
|
||||
// the changed entity's own side: the changed group's whole membership when the
|
||||
// group itself changed (outputGroups), or the changed peer alone when matched via a
|
||||
// peer-seeded group (never its co-members).
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue // a disabled rule grants no access
|
||||
}
|
||||
r.foldRuleSideIfChanged(policy, rule, sideSource)
|
||||
r.foldRuleSideIfChanged(policy, rule, sideDestination)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldRuleSideIfChanged: when a changed group or direct peer sits on `side` of the
|
||||
// rule, fold the opposite side fully (groups/peers + destination routers) and fold
|
||||
// the changed entity's own side (the whole changed group, or the changed peer alone).
|
||||
func (r *resolver) foldRuleSideIfChanged(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
nearGroups := ruleSideGroups(rule, side)
|
||||
nearResource := ruleSideResource(rule, side)
|
||||
|
||||
matchedByGroup := anyInSet(nearGroups, r.linkGroups)
|
||||
matchedByPeer := isDirectPeerInSet(nearResource, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
return
|
||||
}
|
||||
|
||||
// Opposite side, fully down to peers (a destination opposite also folds routers).
|
||||
r.foldPolicySideForRule(policy, rule, side.opposite())
|
||||
|
||||
// Own side: fold the whole changed group's members only when the group itself
|
||||
// changed (outputGroups). A peer-seeded or link-only group is not folded here —
|
||||
// its siblings never refresh. The changed peers themselves are folded once, after
|
||||
// the walk (see walk()).
|
||||
for _, gID := range nearGroups {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// When the changed side IS a destination, the resources it targets are reached
|
||||
// through their network's routers, so those routers refresh too (e.g. attaching a
|
||||
// resource to a destination group, or a changed destination group/resource).
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySideForRule folds one side of a single rule (groups + direct peer), and
|
||||
// for a destination side the routers of that rule's destination resources.
|
||||
func (r *resolver) foldPolicySideForRule(policy *types.Policy, rule *types.PolicyRule, side policySide) {
|
||||
addAll(r.affectedGroups, ruleSideGroups(rule, side))
|
||||
res := ruleSideResource(rule, side)
|
||||
if res.Type == types.ResourceTypePeer && res.ID != "" {
|
||||
r.affectedPeers[res.ID] = struct{}{}
|
||||
}
|
||||
if side == sideDestination {
|
||||
r.foldRoutersForResources(r.ruleDestinationResourceIDs(rule))
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromChangedRoutes folds an explicitly changed route's own groups and peer.
|
||||
func (r *resolver) collectFromChangedRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
addAll(r.affectedGroups, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
// collectFromChangedRouters: a changed router refreshes its OWN backing peer/groups
|
||||
// (the changed entity) and the SOURCE side of every policy reaching a resource on
|
||||
// its network (the router serves the whole network). Sibling routers on the network
|
||||
// are independent and are NOT folded. Passing the old router state keeps a repointed
|
||||
// router's previous backing affected without a post-commit read.
|
||||
func (r *resolver) collectFromChangedRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedRouters: changed router %s on network %s -> folding its own peerGroups=%v peer=%q + sources reaching network resources",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
// collectFromChangedResources: a changed resource refreshes the SOURCE side of the
|
||||
// policies targeting EXACTLY that resource — directly, or via one of the resource's
|
||||
// own groups (old∪new across the change, so a now-detached group's sources still
|
||||
// refresh) — plus the routers serving its network (the resource is reached through
|
||||
// them). It does not touch sibling resources on the same network.
|
||||
func (r *resolver) collectFromChangedResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedResources: changed resource %s on network %s (groups %v) -> folding sources of policies targeting it + its network's routers",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
r.foldPolicySourcesForResource(resource.ID, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{resource.NetworkID: {}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil {
|
||||
// foldPolicySourcesForResource folds the source side of every policy whose
|
||||
// destination is the given resource — referenced directly, or via any of the given
|
||||
// groups (the resource's own old∪new groups, which captures a detached group).
|
||||
func (r *resolver) foldPolicySourcesForResource(resourceID string, groupIDs []string) {
|
||||
groups := toSet(groupIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyTargetsResourceOrGroups(policy, resourceID, groups) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResource: policy %s (%s) targets changed resource %s -> folding its source groups/peers", policy.ID, policy.Name, resourceID)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
// policyTargetsResourceOrGroups reports whether a policy's destination is the given
|
||||
// resource directly, or one of the given destination groups.
|
||||
func policyTargetsResourceOrGroups(policy *types.Policy, resourceID string, groups map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID == resourceID && resourceID != "" {
|
||||
return true
|
||||
}
|
||||
if anyInSet(rule.Destinations, groups) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectFromChangedNetworks: a changed network refreshes the SOURCE side of the
|
||||
// policies reaching any of its resources, plus its routers. A network has no
|
||||
// groups/peers of its own.
|
||||
func (r *resolver) collectFromChangedNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil || network.ID == "" {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromChangedNetworks: changed network %s -> folding sources reaching its resources + its routers", network.ID)
|
||||
resourceIDs := r.networkResourceIDs(network.ID)
|
||||
r.foldPolicySourcesForResources(resourceIDs)
|
||||
r.foldRoutersOnNetworks(map[string]struct{}{network.ID: {}})
|
||||
}
|
||||
}
|
||||
|
||||
// foldPolicySourcesForResources folds the source groups/peers of every policy whose
|
||||
// destination targets one of resourceIDs (directly or via a destination group).
|
||||
func (r *resolver) foldPolicySourcesForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("foldPolicySourcesForResources: policy %s (%s) targets a changed resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.affectedGroups, r.affectedPeers)
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromRoutes folds, per matched route, the OPPOSITE side(s) fully and the
|
||||
// matched side's own groups only on a whole-group change (outputGroups). A route has
|
||||
// three peer sides — routing (Peer/PeerGroups), consumer (Groups) and ACL
|
||||
// (AccessControlGroups) — that each refresh the others; the changed side's own group
|
||||
// folds its siblings only when the group itself changed, never on a one-peer move.
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
if !rt.Enabled {
|
||||
continue // disabled routes route to nobody; skip existing account data
|
||||
}
|
||||
routing := anyInSet(rt.PeerGroups, r.linkGroups) || (rt.Peer != "" && isInSet(rt.Peer, r.changedPeers))
|
||||
consumer := anyInSet(rt.Groups, r.linkGroups)
|
||||
acl := anyInSet(rt.AccessControlGroups, r.linkGroups)
|
||||
if !routing && !consumer && !acl {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (routing=%t consumer=%t acl=%t) -> folding opposite sides; own side gated on outputGroups",
|
||||
rt.ID, routing, consumer, acl)
|
||||
r.foldRouteSide(rt.PeerGroups, routing)
|
||||
r.foldRouteSide(rt.Groups, consumer)
|
||||
r.foldRouteSide(rt.AccessControlGroups, acl)
|
||||
// The single routing Peer folds when the routing side is the OPPOSITE of the
|
||||
// match (consumer/acl need it), or when that very peer is the change.
|
||||
if rt.Peer != "" && (consumer || acl || isInSet(rt.Peer, r.changedPeers)) {
|
||||
r.affectedPeers[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// foldRouteSide folds a route side: when this side is the one that matched, fold its
|
||||
// groups only on a whole-group change (outputGroups) so siblings of a single moved
|
||||
// peer stay put; otherwise it is an opposite side and folds fully.
|
||||
func (r *resolver) foldRouteSide(groups []string, matchedHere bool) {
|
||||
if matchedHere {
|
||||
r.foldOutputGroups(groups)
|
||||
return
|
||||
}
|
||||
addAll(r.affectedGroups, groups)
|
||||
}
|
||||
|
||||
// foldOutputGroups folds only the groups that the caller reported as wholly changed
|
||||
// (outputGroups). Used for a matched object's OWN side, where a peer-seeded or
|
||||
// link-only group must not pull in its siblings.
|
||||
func (r *resolver) foldOutputGroups(groups ...[]string) {
|
||||
for _, gs := range groups {
|
||||
for _, gID := range gs {
|
||||
if _, ok := r.outputGroups[gID]; ok {
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
if len(r.linkGroups) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
if anyInSet(ns.Groups, r.linkGroups) {
|
||||
// A nameserver group has no opposite side: a peer's DNS config depends only
|
||||
// on its own membership, so a one-peer move refreshes that peer alone (folded
|
||||
// elsewhere). Fold the referenced groups only on a whole-group change.
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a linked group -> folding its groups %v (outputGroups only)", ns.ID, ns.Groups)
|
||||
r.foldOutputGroups(ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
if len(r.linkGroups) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
if _, ok := r.linkGroups[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
r.affectedGroups[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromNetworkRouters handles a changed group/peer that BACKS a router (the
|
||||
// routing peer set moved): the router's own peers refresh and so do the sources of
|
||||
// the policies reaching its network's resources. Sibling routers on the network are
|
||||
// independent and are not folded.
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.linkGroups)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeers) > 0 && isInSet(router.Peer, r.changedPeers)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding its peerGroups=%v peer=%q (own groups on outputGroups) + sources reaching network resources",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
// The backing PeerGroups are the matched (own) side: fold them only on a
|
||||
// whole-group change so a one-peer move does not wake sibling backing peers. The
|
||||
// opposite side (policy sources reaching the network) is folded below.
|
||||
r.foldOutputGroups(router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.foldPolicySourcesForResources(r.networkResourceIDs(router.NetworkID))
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,42 +827,48 @@ func (r *resolver) collectFromProxyServices() {
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
if svc == nil || !svc.Enabled {
|
||||
continue // a disabled service proxies nothing; skip existing account data
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.linkGroups)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets; access groups %v on outputGroups only",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
r.affectedPeers[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if !target.Enabled {
|
||||
continue // a disabled target forwards nothing
|
||||
}
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
r.affectedPeers[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
// AccessGroups are the matched (own) side with no opposite to fold: a member's
|
||||
// proxy access is self-contained, so a one-peer move refreshes that peer alone.
|
||||
// Fold the groups only on a whole-group change.
|
||||
r.foldOutputGroups(svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
if len(r.linkGroups) == 0 {
|
||||
return r.changedPeers
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
ids := r.peerIDsForGroups(r.linkGroups)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
return r.changedPeers
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged := make(map[string]struct{}, len(r.changedPeers)+len(ids))
|
||||
for id := range r.changedPeers {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
@@ -570,54 +877,36 @@ func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
// foldRoutersForResources folds the routers serving the networks of the given
|
||||
// resources (a destination resource is reached through its network's routers). It is
|
||||
// the resource -> network -> router hop used by foldPolicySide for a destination.
|
||||
func (r *resolver) foldRoutersForResources(resourceIDs map[string]struct{}) {
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
r.foldRoutersOnNetworks(r.resourceNetworkIDs(resourceIDs))
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
// ruleDestinationResourceIDs returns the destination resource IDs of a single rule:
|
||||
// the direct DestinationResource plus the resources of its destination groups.
|
||||
func (r *resolver) ruleDestinationResourceIDs(rule *types.PolicyRule) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
r.addGroupResourceIDs(toSet(rule.Destinations), resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
// networkResourceIDs returns the IDs of all resources on the given network.
|
||||
func (r *resolver) networkResourceIDs(networkID string) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
if resource.NetworkID == networkID {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
@@ -627,9 +916,9 @@ func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
addAll(r.affectedGroups, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
r.affectedPeers[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -650,6 +939,9 @@ func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
@@ -714,44 +1006,20 @@ func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
// collectPolicySources folds the source groups/peers of a snapshot policy's enabled
|
||||
// rules (a disabled rule grants no access).
|
||||
func collectPolicySources(policy *types.Policy, groups, peers map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
addAll(groups, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
peers[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
@@ -776,7 +1044,7 @@ func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, cha
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
if !target.Enabled || target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
|
||||
// the resolver folds in for a changed policy, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
@@ -19,7 +19,14 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
for _, rule := range p.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
@@ -80,26 +87,6 @@ func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
@@ -107,24 +94,9 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
|
||||
@@ -93,6 +93,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -158,6 +164,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -236,6 +244,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -327,6 +341,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -341,7 +361,6 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
|
||||
var err error
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
})
|
||||
@@ -520,7 +539,12 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
// A membership change affects only the peer itself and the opposite side of THIS
|
||||
// group's policies — not the group's other members, and not the peer's other
|
||||
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
|
||||
// refreshes the peer without seeding its other group memberships. For an
|
||||
// intra-group policy the opposite side is the group, so its members still refresh.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
@@ -586,10 +610,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
|
||||
// policies refresh, not the group's other members or the peer's other groups. The
|
||||
// peer is no longer in the group's index, but LinkGroups still drives the
|
||||
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
|
||||
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
@@ -600,8 +625,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
|
||||
156
management/server/migration/account_seq.go
Normal file
156
management/server/migration/account_seq.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||
// no-op once everything is consistent.
|
||||
//
|
||||
// Implemented as two table-wide SQL statements with window functions, one
|
||||
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||
// well under a second instead of the per-account-loop ~2 minutes.
|
||||
//
|
||||
// orderColumn is the column to use when assigning the deterministic ordering
|
||||
// (typically the primary-key string id).
|
||||
func BackfillAccountSeqIDs[T any](
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
entity types.AccountSeqEntity,
|
||||
orderColumn string,
|
||||
) error {
|
||||
var model T
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
table := quoteIdent(db, stmt.Schema.Table)
|
||||
orderCol := quoteIdent(db, orderColumn)
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var pending int64
|
||||
if err := tx.Raw(
|
||||
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||
).Scan(&pending).Error; err != nil {
|
||||
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||
}
|
||||
|
||||
if pending > 0 {
|
||||
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||
return fmt.Errorf("rank %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func quoteIdent(db *gorm.DB, name string) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql":
|
||||
return "`" + name + "`"
|
||||
case "postgres":
|
||||
return `"` + name + `"`
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres", "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
WITH max_seq AS (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
),
|
||||
ranked AS (
|
||||
SELECT p.id,
|
||||
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||
FROM %s p
|
||||
JOIN max_seq m ON p.account_id = m.account_id
|
||||
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||
)
|
||||
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||
FROM ranked
|
||||
WHERE %s.id = ranked.id
|
||||
`, table, orderCol, table, table, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
UPDATE %s p
|
||||
JOIN (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
) m ON p.account_id = m.account_id
|
||||
JOIN (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||
) r ON p.id = r.id
|
||||
SET p.account_seq_id = m.max_seq + r.rn
|
||||
`, table, table, orderCol, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql).Error
|
||||
}
|
||||
|
||||
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`, table)
|
||||
case "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql, string(entity)).Error
|
||||
}
|
||||
@@ -67,6 +67,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNSGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -116,6 +122,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -71,9 +72,20 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
network.ID = xid.New().String()
|
||||
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, network.AccountID, serverTypes.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network seq id: %w", err)
|
||||
}
|
||||
network.AccountSeqID = seq
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
|
||||
@@ -102,14 +114,25 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existing, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
network.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err := transaction.SaveNetwork(ctx, network); err != nil {
|
||||
return fmt.Errorf("failed to save network: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
|
||||
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
|
||||
@@ -255,3 +255,73 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, updatedNetwork)
|
||||
}
|
||||
|
||||
// Test_CreateNetworkAllocatesSeqID verifies that CreateNetwork sets a
|
||||
// non-zero AccountSeqID on the persisted network (allocated through the
|
||||
// account_seq_counters table).
|
||||
func Test_CreateNetworkAllocatesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-allocation-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "CreateNetwork must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// Test_UpdateNetworkPreservesSeqID verifies UpdateNetwork does not reset
|
||||
// AccountSeqID even when the caller passes a zero value (the shape REST
|
||||
// handlers produce because the field is `json:"-"`).
|
||||
func Test_UpdateNetworkPreservesSeqID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const accountID = "testAccountId"
|
||||
const userID = "testAdminId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(ctx, "../testdata/networks.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
am := mock_server.MockAccountManager{}
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
routerManager := routers.NewManagerMock()
|
||||
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
|
||||
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
|
||||
|
||||
created, err := manager.CreateNetwork(ctx, userID, &types.Network{
|
||||
AccountID: accountID,
|
||||
Name: "seq-preserve-original",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &types.Network{
|
||||
AccountID: accountID,
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = manager.UpdateNetwork(ctx, userID, update)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := manager.GetNetwork(ctx, accountID, userID, created.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive UpdateNetwork")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -146,6 +146,12 @@ func (m *managerImpl) createResourceInTransaction(ctx context.Context, transacti
|
||||
return nil, nil, fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to allocate network resource seq id: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
|
||||
}
|
||||
@@ -245,6 +251,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = oldResource.AccountSeqID
|
||||
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,9 @@ type NetworkResource struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
|
||||
Name string
|
||||
Description string
|
||||
Type NetworkResourceType
|
||||
@@ -93,17 +96,18 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
|
||||
|
||||
func (n *NetworkResource) Copy() *NetworkResource {
|
||||
return &NetworkResource{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
Type: n.Type,
|
||||
Address: n.Address,
|
||||
Domain: n.Domain,
|
||||
Prefix: n.Prefix,
|
||||
GroupIDs: n.GroupIDs,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
serverTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -104,6 +105,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", err)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
|
||||
err = transaction.CreateNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
@@ -199,6 +206,11 @@ func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction
|
||||
return nil, nil, affectedpeers.Change{}, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
// Preserve AccountSeqID from the existing router so the upstream
|
||||
// UpdateNetworkRouter (which does Updates(router) with Select("*"))
|
||||
// doesn't clobber it with the request's zero value.
|
||||
router.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
|
||||
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to update network router: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ type NetworkRouter struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
Masquerade bool
|
||||
@@ -78,14 +81,15 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
|
||||
|
||||
func (n *NetworkRouter) Copy() *NetworkRouter {
|
||||
return &NetworkRouter{
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
ID: n.ID,
|
||||
NetworkID: n.NetworkID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Peer: n.Peer,
|
||||
PeerGroups: n.PeerGroups,
|
||||
Masquerade: n.Masquerade,
|
||||
Metric: n.Metric,
|
||||
Enabled: n.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,12 +7,24 @@ import (
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_networks_account_seq_id;not null;default:0"`
|
||||
|
||||
Name string
|
||||
Description string
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the network has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip networks that return false here.
|
||||
func (n *Network) HasSeqID() bool {
|
||||
return n != nil && n.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func NewNetwork(accountId, name, description string) *Network {
|
||||
return &Network{
|
||||
ID: xid.New().String(),
|
||||
@@ -41,13 +53,14 @@ func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// Copy returns a copy of a posture checks.
|
||||
// Copy returns a copy of a network.
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
ID: n.ID,
|
||||
AccountID: n.AccountID,
|
||||
AccountSeqID: n.AccountSeqID,
|
||||
Name: n.Name,
|
||||
Description: n.Description,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -209,14 +209,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
|
||||
return nil
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
@@ -1052,7 +1052,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
|
||||
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged(), metaDiff.HostnameChanged()) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1063,6 +1063,29 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
|
||||
var reason string
|
||||
switch {
|
||||
case isStatusChanged:
|
||||
reason = "status changed"
|
||||
case updateAccountPeers:
|
||||
reason = "update account peers"
|
||||
case ipv6CapabilityChanged:
|
||||
reason = "ipv6 capability changed"
|
||||
case metaDiffAffectsPosture:
|
||||
reason = "meta diff affects posture"
|
||||
case versionChanged:
|
||||
reason = "version changed"
|
||||
case hostname:
|
||||
reason = "hostname changed"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("peer update required: %s", reason)
|
||||
return true
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
|
||||
// Peer capability constants mirror the proto enum values.
|
||||
const (
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilitySourcePrefixes int32 = 1
|
||||
PeerCapabilityIPv6Overlay int32 = 2
|
||||
PeerCapabilityComponentNetworkMap int32 = 3
|
||||
)
|
||||
|
||||
// Peer represents a machine connected to the network.
|
||||
@@ -218,6 +219,14 @@ func (p *Peer) SupportsSourcePrefixes() bool {
|
||||
return p.HasCapability(PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
|
||||
// SupportsComponentNetworkMap reports whether the peer assembles its
|
||||
// NetworkMap from server-shipped components instead of consuming a fully
|
||||
// expanded NetworkMap. Determines whether the network_map controller skips
|
||||
// Calculate() server-side and emits the components envelope.
|
||||
func (p *Peer) SupportsComponentNetworkMap() bool {
|
||||
return p.HasCapability(PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func capabilitiesEqual(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -2893,3 +2894,141 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
|
||||
// returns a record built from the configured city geoname id, or an error when set.
|
||||
type fakeGeo struct {
|
||||
geoNameID uint
|
||||
isoCode string
|
||||
cityName string
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
|
||||
if g.err != nil {
|
||||
return nil, g.err
|
||||
}
|
||||
record := &geolocation.Record{}
|
||||
record.City.GeonameID = g.geoNameID
|
||||
record.City.Names.En = g.cityName
|
||||
record.Country.ISOCode = g.isoCode
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) Stop() error { return nil }
|
||||
|
||||
func TestResolvePeerLocation(t *testing.T) {
|
||||
realIP := net.ParseIP("203.0.113.10")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
geo geolocation.Geolocation
|
||||
peer *nbpeer.Peer
|
||||
realIP net.IP
|
||||
want *nbpeer.Location
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "no geo configured returns nil",
|
||||
geo: nil,
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil real IP returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "lookup error returns nil",
|
||||
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP and same geoname returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP but changed geoname returns location",
|
||||
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City B",
|
||||
GeoNameID: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IP returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("198.51.100.7"),
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no prior location returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
am := &DefaultAccountManager{geo: tt.geo}
|
||||
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
|
||||
if tt.wantNil {
|
||||
assert.Nil(t, got, "resolved location should be nil")
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got, "resolved location should not be nil")
|
||||
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
|
||||
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
|
||||
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
|
||||
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +67,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
policy.AccountSeqID = existingPolicy.AccountSeqID
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -49,6 +49,10 @@ type Checks struct {
|
||||
// AccountID is a reference to the Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_posture_checks_account_seq_id;not null;default:0"`
|
||||
|
||||
// Checks is a set of objects that perform the actual checks
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
@@ -93,6 +97,13 @@ func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Pe
|
||||
return changed
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the posture check has been persisted long enough
|
||||
// to have a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip checks that return false here.
|
||||
func (pc *Checks) HasSeqID() bool {
|
||||
return pc != nil && pc.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
@@ -163,11 +174,12 @@ func (*Checks) TableName() string {
|
||||
// Copy returns a copy of a posture checks.
|
||||
func (pc *Checks) Copy() *Checks {
|
||||
checks := &Checks{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: pc.Description,
|
||||
AccountID: pc.AccountID,
|
||||
AccountSeqID: pc.AccountSeqID,
|
||||
Checks: pc.Checks.Copy(),
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -52,7 +53,19 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
existing, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = existing.AccountSeqID
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
} else {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postureChecks.AccountSeqID = seq
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
|
||||
@@ -489,6 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
policy := &types.Policy{
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
@@ -562,3 +563,61 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
assert.Empty(t, directPeerIDs)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_AllocatesSeqIDOnCreate verifies that the create path
|
||||
// (no incoming ID) allocates a non-zero AccountSeqID via the
|
||||
// account_seq_counters table.
|
||||
func TestSavePostureChecks_AllocatesSeqIDOnCreate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-allocation-test",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, created.AccountSeqID, "SavePostureChecks on create must allocate a non-zero AccountSeqID")
|
||||
}
|
||||
|
||||
// TestSavePostureChecks_PreservesSeqIDOnUpdate verifies the update path does
|
||||
// not reset AccountSeqID even when the caller passes a zero value (REST
|
||||
// handler shape, because the field is `json:"-"`).
|
||||
func TestSavePostureChecks_PreservesSeqIDOnUpdate(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
created, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
Name: "seq-preserve-original",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
originalSeq := created.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
update := &posture.Checks{
|
||||
ID: created.ID,
|
||||
Name: "seq-preserve-renamed",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.27.0"},
|
||||
},
|
||||
}
|
||||
require.Zero(t, update.AccountSeqID, "incoming struct must mirror an HTTP handler shape")
|
||||
|
||||
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, update, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := am.GetPostureChecks(context.Background(), account.Id, created.ID, adminUserID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must survive SavePostureChecks update")
|
||||
require.Equal(t, "seq-preserve-renamed", got.Name)
|
||||
}
|
||||
|
||||
@@ -175,6 +175,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRoute.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -222,6 +228,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
routeToSave.AccountID = accountID
|
||||
routeToSave.AccountSeqID = oldRoute.AccountSeqID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
|
||||
506
management/server/store/account_seq_test.go
Normal file
506
management/server/store/account_seq_test.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
var errRollback = errors.New("intentional rollback")
|
||||
|
||||
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accA = "acc-a"
|
||||
const accB = "acc-b"
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different account starts from 1")
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different entity starts from 1")
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), got, "counter persists across transactions")
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, policies, "test fixture must have policies")
|
||||
|
||||
seen := make(map[uint32]bool)
|
||||
for _, p := range policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
|
||||
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
|
||||
seen[p.AccountSeqID] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
const policyID = "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
|
||||
|
||||
updated := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Enabled: false,
|
||||
Rules: original.Rules,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
|
||||
|
||||
require.NoError(t, store.SavePolicy(ctx, updated))
|
||||
|
||||
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups)
|
||||
|
||||
original := groups[0]
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
updated := &types.Group{
|
||||
ID: original.ID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Issued: original.Issued,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID)
|
||||
|
||||
require.NoError(t, store.UpdateGroup(ctx, updated))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-seqid-test"
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
DNSSettings: types.DNSSettings{},
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
}
|
||||
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
|
||||
require.Len(t, account.Groups, 1, "default 'All' group must be present")
|
||||
require.Len(t, account.Policies, 1, "default policy must be present")
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
|
||||
}
|
||||
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
|
||||
|
||||
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
|
||||
|
||||
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
|
||||
return errRollback
|
||||
}), errRollback)
|
||||
}
|
||||
|
||||
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupSeqs := make(map[string]uint32)
|
||||
policySeqs := make(map[string]uint32)
|
||||
routeSeqs := make(map[route.ID]uint32)
|
||||
nsgSeqs := make(map[string]uint32)
|
||||
resourceSeqs := make(map[string]uint32)
|
||||
routerSeqs := make(map[string]uint32)
|
||||
networkSeqs := make(map[string]uint32)
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
|
||||
groupSeqs[g.ID] = g.AccountSeqID
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
|
||||
policySeqs[p.ID] = p.AccountSeqID
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
|
||||
routeSeqs[r.ID] = r.AccountSeqID
|
||||
}
|
||||
for _, n := range account.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
|
||||
nsgSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
|
||||
resourceSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
|
||||
routerSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture network must have seq>0 after backfill")
|
||||
networkSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, g := range after.Groups {
|
||||
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.Equal(t, networkSeqs[n.ID], n.AccountSeqID, "network %s seq must be preserved", n.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-all-entities"
|
||||
|
||||
addr, err := netip.ParseAddr("8.8.8.8")
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{Identifier: "net-test"},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
|
||||
},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
|
||||
},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
|
||||
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{ID: "n1", AccountID: accountID, Name: "n1"},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: "pc1", AccountID: accountID, Name: "pc1",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, after.Groups, 1)
|
||||
require.Len(t, after.Policies, 1)
|
||||
require.Len(t, after.Routes, 1)
|
||||
require.Len(t, after.NameServerGroups, 1)
|
||||
require.Len(t, after.NetworkResources, 1)
|
||||
require.Len(t, after.NetworkRouters, 1)
|
||||
require.Len(t, after.Networks, 1)
|
||||
require.Len(t, after.PostureChecks, 1)
|
||||
|
||||
for _, g := range after.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
|
||||
}
|
||||
for _, n := range after.Networks {
|
||||
require.NotZero(t, n.AccountSeqID, "network seq must be allocated")
|
||||
}
|
||||
for _, pc := range after.PostureChecks {
|
||||
require.NotZero(t, pc.AccountSeqID, "posture_check seq must be allocated")
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, after))
|
||||
final, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, r := range final.Routes {
|
||||
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
|
||||
}
|
||||
for _, n := range final.NameServerGroups {
|
||||
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
|
||||
}
|
||||
afterByID := map[string]uint32{}
|
||||
for _, n := range after.Networks {
|
||||
afterByID[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, n := range final.Networks {
|
||||
require.Equal(t, afterByID[n.ID], n.AccountSeqID, "network seq preserved on re-save")
|
||||
}
|
||||
afterPCByID := map[string]uint32{}
|
||||
for _, pc := range after.PostureChecks {
|
||||
afterPCByID[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
for _, pc := range final.PostureChecks {
|
||||
require.Equal(t, afterPCByID[pc.ID], pc.AccountSeqID, "posture_check seq preserved on re-save")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "concurrent-test"
|
||||
const entity = types.AccountSeqEntityPolicy
|
||||
const goroutines = 32
|
||||
|
||||
type result struct {
|
||||
seq uint32
|
||||
err error
|
||||
}
|
||||
results := make(chan result, goroutines)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
<-start
|
||||
var allocated uint32
|
||||
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
|
||||
allocated = seq
|
||||
return err
|
||||
})
|
||||
results <- result{seq: allocated, err: err}
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
|
||||
seen := make(map[uint32]int, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
r := <-results
|
||||
require.NoError(t, r.err, "concurrent allocate must not fail")
|
||||
require.NotZero(t, r.seq, "allocated seq must be non-zero")
|
||||
seen[r.seq]++
|
||||
}
|
||||
|
||||
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
|
||||
for i := uint32(1); i <= goroutines; i++ {
|
||||
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*types.Group{
|
||||
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
|
||||
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
for _, want := range groups {
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
|
||||
}
|
||||
|
||||
groups[0].Name = "g1-renamed"
|
||||
groups[0].AccountSeqID = 0
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
|
||||
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
|
||||
}
|
||||
|
||||
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
maxSeq := uint32(0)
|
||||
for _, p := range existing {
|
||||
if p.AccountSeqID > maxSeq {
|
||||
maxSeq = p.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
|
||||
|
||||
newPolicy := &types.Policy{
|
||||
ID: "bench-new-policy",
|
||||
AccountID: accountID,
|
||||
AccountSeqID: seq,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "bench-new-policy-rule",
|
||||
PolicyID: "bench-new-policy",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
}},
|
||||
}
|
||||
return tx.CreatePolicy(ctx, newPolicy)
|
||||
}))
|
||||
|
||||
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxSeq+1, created.AccountSeqID)
|
||||
}
|
||||
@@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
&types.AccountSeqCounter{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -308,6 +309,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
|
||||
return fmt.Errorf("assign seq ids: %w", err)
|
||||
}
|
||||
|
||||
result = tx.
|
||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
@@ -637,6 +642,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
}
|
||||
|
||||
// CreateGroups creates the given list of groups to the database.
|
||||
// groupUpsertColumns is the explicit allowlist of columns that get updated when
|
||||
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
|
||||
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
|
||||
// handler-built struct) cannot reset the persisted seq id during an upsert.
|
||||
// Keep this in sync with the Group schema in management/server/types/group.go.
|
||||
func groupUpsertColumns() clause.Set {
|
||||
return clause.AssignmentColumns([]string{
|
||||
"account_id",
|
||||
"name",
|
||||
"issued",
|
||||
"integration_ref_id",
|
||||
"integration_ref_integration_type",
|
||||
"resources",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
@@ -646,8 +667,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -671,8 +693,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -2018,7 +2041,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
|
||||
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2028,7 +2051,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
var resources []byte
|
||||
var refID sql.NullInt64
|
||||
var refType sql.NullString
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
if err == nil {
|
||||
if refID.Valid {
|
||||
g.IntegrationReference.ID = int(refID.Int64)
|
||||
@@ -2053,7 +2076,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
|
||||
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2062,7 +2085,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
var p types.Policy
|
||||
var checks []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
p.Enabled = enabled.Bool
|
||||
@@ -2080,7 +2103,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
}
|
||||
|
||||
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
|
||||
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2090,7 +2113,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
var network, domains, peerGroups, groups, accessGroups []byte
|
||||
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
if err == nil {
|
||||
if keepRoute.Valid {
|
||||
r.KeepRoute = keepRoute.Bool
|
||||
@@ -2132,7 +2155,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
|
||||
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2141,7 +2164,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
var n nbdns.NameServerGroup
|
||||
var ns, groups, domains []byte
|
||||
var primary, enabled, searchDomainsEnabled sql.NullBool
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
if err == nil {
|
||||
if primary.Valid {
|
||||
n.Primary = primary.Bool
|
||||
@@ -2177,7 +2200,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, checks FROM posture_checks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2185,7 +2208,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) {
|
||||
var c posture.Checks
|
||||
var checksDef []byte
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef)
|
||||
err := row.Scan(&c.ID, &c.AccountID, &c.AccountSeqID, &c.Name, &c.Description, &checksDef)
|
||||
if err == nil && checksDef != nil {
|
||||
_ = json.Unmarshal(checksDef, &c.Checks)
|
||||
}
|
||||
@@ -2365,7 +2388,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) {
|
||||
const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description FROM networks WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2382,7 +2405,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2392,7 +2415,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
var peerGroups []byte
|
||||
var masquerade, enabled sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
if err == nil {
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
@@ -2420,7 +2443,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2429,7 +2452,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
|
||||
var r resourceTypes.NetworkResource
|
||||
var prefix []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -3602,6 +3625,262 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must be called inside ExecuteInTransaction so the increment
|
||||
// is serialized with the component insert.
|
||||
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
|
||||
}
|
||||
|
||||
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
return allocateAccountSeqIDReturning(db, accountID, entity)
|
||||
case types.MysqlStoreEngine:
|
||||
return allocateAccountSeqIDMysql(db, accountID, entity)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
|
||||
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
|
||||
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
|
||||
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
|
||||
// branch and returns next_id+1.
|
||||
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, 2)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = account_seq_counters.next_id + 1
|
||||
RETURNING (next_id - 1)
|
||||
`
|
||||
var allocated uint32
|
||||
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
if allocated == 0 {
|
||||
return 0, fmt.Errorf("upsert account seq counter returned 0")
|
||||
}
|
||||
return allocated, nil
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
|
||||
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
|
||||
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
|
||||
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
|
||||
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
|
||||
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
|
||||
// so the follow-up SELECT sees the same value.
|
||||
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const upsertSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, LAST_INSERT_ID(2))
|
||||
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
|
||||
`
|
||||
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
var newNext uint64
|
||||
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
|
||||
return 0, fmt.Errorf("get last insert id: %w", err)
|
||||
}
|
||||
if newNext == 0 {
|
||||
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
|
||||
}
|
||||
return uint32(newNext - 1), nil
|
||||
}
|
||||
|
||||
// assignAccountSeqIDs allocates a per-account integer id for any component on
|
||||
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
|
||||
// the canonical "save the whole account" path produces the same persisted seq
|
||||
// ids that the manager-level Create paths produce. Update flows that go
|
||||
// through SaveAccount preserve existing non-zero values; for those, the
|
||||
// per-entity counter is bumped so subsequent AllocateAccountSeqID calls don't
|
||||
// hand out a colliding id.
|
||||
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
|
||||
maxByEntity := make(map[types.AccountSeqEntity]uint32, 8)
|
||||
bump := func(entity types.AccountSeqEntity, seq uint32) {
|
||||
if seq > maxByEntity[entity] {
|
||||
maxByEntity[entity] = seq
|
||||
}
|
||||
}
|
||||
|
||||
for i := range account.GroupsG {
|
||||
g := account.GroupsG[i]
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
if g.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityGroup, g.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
// Defensive: generateAccountSQLTypes currently aliases the same
|
||||
// *Group pointer into GroupsG and Groups[id] (so this is a no-op
|
||||
// today), but mirror the seq anyway so any future divergence in
|
||||
// how the two collections are populated doesn't silently leave
|
||||
// the canonical map view stale.
|
||||
if original, ok := account.Groups[g.ID]; ok && original != nil && original != g {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if p.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPolicy, p.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.AccountSeqID = seq
|
||||
}
|
||||
for i := range account.RoutesG {
|
||||
r := &account.RoutesG[i]
|
||||
if r.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityRoute, r.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.AccountSeqID = seq
|
||||
// Mirror the new seq onto the canonical map view so callers that
|
||||
// hold the same in-memory account post-Save read a consistent
|
||||
// AccountSeqID — without this, components/encoder code would see
|
||||
// 0 for routes saved this transaction until the account is reloaded.
|
||||
if original, ok := account.Routes[r.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for i := range account.NameServerGroupsG {
|
||||
ng := &account.NameServerGroupsG[i]
|
||||
if ng.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNameserverGroup, ng.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ng.AccountSeqID = seq
|
||||
if original, ok := account.NameServerGroups[ng.ID]; ok && original != nil {
|
||||
original.AccountSeqID = seq
|
||||
}
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkResource, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
if nr == nil {
|
||||
continue
|
||||
}
|
||||
if nr.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetworkRouter, nr.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, n := range account.Networks {
|
||||
if n == nil {
|
||||
continue
|
||||
}
|
||||
if n.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityNetwork, n.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetwork)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.AccountSeqID = seq
|
||||
}
|
||||
for _, pc := range account.PostureChecks {
|
||||
if pc == nil {
|
||||
continue
|
||||
}
|
||||
if pc.AccountSeqID != 0 {
|
||||
bump(types.AccountSeqEntityPostureCheck, pc.AccountSeqID)
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPostureCheck)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pc.AccountSeqID = seq
|
||||
}
|
||||
for entity, maxSeq := range maxByEntity {
|
||||
if err := ensureAccountSeqCounter(tx, s.storeEngine, account.Id, entity, maxSeq+1); err != nil {
|
||||
return fmt.Errorf("seed counter for %s: %w", entity, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureAccountSeqCounter raises the per-account counter for entity to at
|
||||
// least target. Used when SaveAccount persists components that already carry
|
||||
// AccountSeqIDs (e.g. test bulk-load from sqlite to postgres, or migrations
|
||||
// running before component data lands) so that the next AllocateAccountSeqID
|
||||
// call returns a fresh id beyond what was just written.
|
||||
func ensureAccountSeqCounter(db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity, target uint32) error {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`
|
||||
// sqlite's UPSERT understands max() but the migration uses GREATEST
|
||||
// for postgres and max() for sqlite. We collapse to dialect-specific
|
||||
// statements only when needed.
|
||||
if engine == types.SqliteStoreEngine {
|
||||
const sqliteSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`
|
||||
return db.Exec(sqliteSQL, accountID, string(entity), target).Error
|
||||
}
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
case types.MysqlStoreEngine:
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`
|
||||
return db.Exec(sqlStr, accountID, string(entity), target).Error
|
||||
default:
|
||||
return fmt.Errorf("unsupported store engine for account_seq counter: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
|
||||
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
|
||||
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
|
||||
@@ -3791,7 +4070,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
@@ -3879,7 +4158,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
|
||||
@@ -222,6 +222,11 @@ type Store interface {
|
||||
GetStoreEngine() types.Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must run inside a transaction so the increment is serialized
|
||||
// with the component insert.
|
||||
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
||||
@@ -558,6 +563,30 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[networkTypes.Network](ctx, db, types.AccountSeqEntityNetwork, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[posture.Checks](ctx, db, types.AccountSeqEntityPostureCheck, "id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -774,6 +774,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID mocks base method.
|
||||
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
|
||||
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -42,27 +41,8 @@ const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
|
||||
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
|
||||
firewallRuleMinPortRangesVer = "0.48.0"
|
||||
// firewallRuleMinNativeSSHVer defines the minimum peer version that supports native SSH features in the firewall rules.
|
||||
firewallRuleMinNativeSSHVer = "0.60.0"
|
||||
|
||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||
nativeSSHPortString = "22022"
|
||||
nativeSSHPortNumber = 22022
|
||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||
defaultSSHPortString = "22"
|
||||
defaultSSHPortNumber = 22
|
||||
)
|
||||
|
||||
type supportedFeatures struct {
|
||||
nativeSSH bool
|
||||
portRanges bool
|
||||
}
|
||||
|
||||
type LookupMap map[string]struct{}
|
||||
|
||||
// AccountMeta is a struct that contains a stripped down version of the Account object.
|
||||
// It doesn't carry any peers, groups, policies, or routes, etc. Just some metadata (e.g. ID, created by, created at, etc).
|
||||
type AccountMeta struct {
|
||||
@@ -1070,7 +1050,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
} else if peerInDestinations && PolicyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
@@ -1136,15 +1116,15 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
} else {
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
rules = append(rules, ExpandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
|
||||
rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{
|
||||
direction: direction,
|
||||
dirStr: strconv.Itoa(direction),
|
||||
protocolStr: string(protocol),
|
||||
actionStr: string(rule.Action),
|
||||
portsJoined: strings.Join(rule.Ports, ","),
|
||||
rules = AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, FirewallRuleContext{
|
||||
Direction: direction,
|
||||
DirStr: strconv.Itoa(direction),
|
||||
ProtocolStr: string(protocol),
|
||||
ActionStr: string(rule.Action),
|
||||
PortsJoined: strings.Join(rule.Ports, ","),
|
||||
})
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
@@ -1152,10 +1132,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
// PeerSSHEnabledFromPolicies is the network-map-free equivalent of the sshEnabled
|
||||
// determination in GetPeerConnectionResources / CalculateNetworkMapFromComponents.
|
||||
func PeerSSHEnabledFromPolicies(policies []*Policy, peerID string, peerGroupIDs map[string]struct{}, peerSSHEnabled bool) bool {
|
||||
@@ -1170,7 +1146,7 @@ func PeerSSHEnabledFromPolicies(policies []*Policy, peerID string, peerGroupIDs
|
||||
}
|
||||
|
||||
isSSHRule := rule.Protocol == PolicyRuleProtocolNetbirdSSH ||
|
||||
(policyRuleImpliesLegacySSH(rule) && peerSSHEnabled)
|
||||
(PolicyRuleImpliesLegacySSH(rule) && peerSSHEnabled)
|
||||
if !isSSHRule {
|
||||
continue
|
||||
}
|
||||
@@ -1197,24 +1173,6 @@ func ruleHasDestination(rule *PolicyRule, peerID string, peerGroupIDs map[string
|
||||
return false
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsIncludesSSH(ports []string) bool {
|
||||
for _, port := range ports {
|
||||
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
@@ -1314,7 +1272,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
|
||||
rules := GenerateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
@@ -1807,96 +1765,6 @@ func (a *Account) createProxyPolicy(svc *service.Service, target *service.Target
|
||||
}
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
|
||||
var expanded []*FirewallRule
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
for _, portRange := range rule.PortRanges {
|
||||
// prefer PolicyRule.Ports
|
||||
if len(rule.Ports) > 0 {
|
||||
break
|
||||
}
|
||||
fr := base
|
||||
|
||||
if features.portRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
expanded = addNativeSSHRule(base, expanded)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// addNativeSSHRule adds a native SSH rule (port 22022) to the expanded rules if the base rule has port 22 configured.
|
||||
func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule {
|
||||
shouldAdd := false
|
||||
for _, fr := range expanded {
|
||||
if isPortInRule(nativeSSHPortString, 22022, fr) {
|
||||
return expanded
|
||||
}
|
||||
if isPortInRule(defaultSSHPortString, 22, fr) {
|
||||
shouldAdd = true
|
||||
}
|
||||
}
|
||||
if !shouldAdd {
|
||||
return expanded
|
||||
}
|
||||
|
||||
fr := base
|
||||
fr.Port = nativeSSHPortString
|
||||
return append(expanded, &fr)
|
||||
}
|
||||
|
||||
func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool {
|
||||
return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End)
|
||||
}
|
||||
|
||||
// shouldCheckRulesForNativeSSH determines whether specific policy rules should be checked for native SSH support.
|
||||
// While users can add the nativeSSHPortString, we look for cases when they used port 22 and based on SSH enabled
|
||||
// in both management and client, we indicate to add the native port.
|
||||
func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool {
|
||||
return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP
|
||||
}
|
||||
|
||||
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
|
||||
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
|
||||
if version.IsDevelopmentVersion(peerVer) {
|
||||
return supportedFeatures{true, true}
|
||||
}
|
||||
|
||||
var features supportedFeatures
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer)
|
||||
features.nativeSSH = err == nil && meetMinVer
|
||||
|
||||
if features.nativeSSH {
|
||||
features.portRanges = true
|
||||
} else {
|
||||
meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
features.portRanges = err == nil && meetMinVer
|
||||
}
|
||||
|
||||
return features
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
// AAAA records are excluded when the requesting peer lacks IPv6 capability.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
|
||||
@@ -16,6 +16,49 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// GetPeerNetworkMapResult dispatches to either the legacy-NetworkMap path or
|
||||
// the components path based on the peer's capability and the kill switch.
|
||||
// Capable peers (PeerCapabilityComponentNetworkMap) get the raw components
|
||||
// shape — the server skips Calculate() entirely for them, saving CPU
|
||||
// proportional to the number of capable peers in the account. Legacy peers
|
||||
// (or any peer when componentsDisabled is true) get the fully-expanded
|
||||
// NetworkMap as before.
|
||||
func (a *Account) GetPeerNetworkMapResult(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
componentsDisabled bool,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) PeerNetworkMapResult {
|
||||
peer := a.Peers[peerID]
|
||||
if !componentsDisabled && peer != nil && peer.SupportsComponentNetworkMap() {
|
||||
components := a.GetPeerNetworkMapComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs,
|
||||
)
|
||||
// Mirror legacy graceful-degrade: GetPeerNetworkMapFromComponents
|
||||
// returns &NetworkMap{Network: a.Network.Copy()} when components is
|
||||
// nil. Match that floor so the receiving client always sees the
|
||||
// account Network identifier, not a fully-empty envelope.
|
||||
if components == nil {
|
||||
components = &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
return PeerNetworkMapResult{Components: components}
|
||||
}
|
||||
return PeerNetworkMapResult{
|
||||
NetworkMap: a.GetPeerNetworkMapFromComponents(
|
||||
ctx, peerID, peersCustomZone, accountZones, validatedPeersMap, resourcePolicies, routers, metrics, groupIDToUserIDs,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapFromComponents(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
@@ -82,15 +125,27 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
}
|
||||
|
||||
components := &NetworkMapComponents{
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
PeerID: peerID,
|
||||
Network: a.Network.Copy(),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
CustomZoneDomain: peersCustomZone.Domain,
|
||||
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
NetworkXIDToSeq: make(map[string]uint32, len(a.Networks)),
|
||||
PostureCheckXIDToSeq: make(map[string]uint32, len(a.PostureChecks)),
|
||||
}
|
||||
for _, n := range a.Networks {
|
||||
if n != nil && n.HasSeqID() {
|
||||
components.NetworkXIDToSeq[n.ID] = n.AccountSeqID
|
||||
}
|
||||
}
|
||||
for _, pc := range a.PostureChecks {
|
||||
if pc != nil && pc.HasSeqID() {
|
||||
components.PostureCheckXIDToSeq[pc.ID] = pc.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
components.AccountSettings = &AccountSettingsInfo{
|
||||
@@ -209,21 +264,26 @@ func (a *Account) GetPeerNetworkMapComponents(
|
||||
components.ResourcePoliciesMap[resource.ID] = policies
|
||||
}
|
||||
|
||||
components.RoutersMap[resource.NetworkID] = networkRoutingPeers
|
||||
for peerIDKey := range networkRoutingPeers {
|
||||
if p := a.Peers[peerIDKey]; p != nil {
|
||||
if _, exists := components.RouterPeers[peerIDKey]; !exists {
|
||||
components.RouterPeers[peerIDKey] = p
|
||||
}
|
||||
if _, exists := components.Peers[peerIDKey]; !exists {
|
||||
if _, validated := validatedPeersMap[peerIDKey]; validated {
|
||||
components.Peers[peerIDKey] = p
|
||||
// Only expose router peers and the per-network routers_map when this
|
||||
// target peer actually has access to the resource (either as a router
|
||||
// itself or via a policy that includes it as a source). Without this
|
||||
// gate, every peer's envelope was leaking router peers of every
|
||||
// network in the account — accounts with many tenants/networks
|
||||
// shipped tens of unrelated peers in `peers[]` and `routers_map`.
|
||||
if addSourcePeers {
|
||||
components.RoutersMap[resource.NetworkID] = networkRoutingPeers
|
||||
for peerIDKey := range networkRoutingPeers {
|
||||
if p := a.Peers[peerIDKey]; p != nil {
|
||||
if _, exists := components.RouterPeers[peerIDKey]; !exists {
|
||||
components.RouterPeers[peerIDKey] = p
|
||||
}
|
||||
if _, exists := components.Peers[peerIDKey]; !exists {
|
||||
if _, validated := validatedPeersMap[peerIDKey]; validated {
|
||||
components.Peers[peerIDKey] = p
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addSourcePeers {
|
||||
components.NetworkResources = append(components.NetworkResources, resource)
|
||||
}
|
||||
}
|
||||
@@ -254,18 +314,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
|
||||
relevantPeerIDs[peerID] = a.GetPeer(peerID)
|
||||
|
||||
peerGroupSet := make(map[string]struct{}, 8)
|
||||
for groupID, group := range a.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
peerGroupSet[groupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
routeAccessControlGroups := make(map[string]struct{})
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.Groups {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
relevant := r.Peer == peerID
|
||||
if !relevant {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant && r.Enabled {
|
||||
for _, groupID := range r.Groups {
|
||||
if _, ok := peerGroupSet[groupID]; ok {
|
||||
relevant = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !relevant {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, groupID := range r.PeerGroups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
for _, groupID := range r.PeerGroups {
|
||||
for _, groupID := range r.Groups {
|
||||
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||
}
|
||||
if r.Enabled {
|
||||
@@ -274,6 +360,44 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
routeAccessControlGroups[groupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Include route advertisers in relevantPeerIDs. The envelope
|
||||
// encoder writes route.peer_index by looking up r.Peer in the
|
||||
// shipped peers list; if the advertiser is policy-isolated from
|
||||
// the target peer (no rule edge between them), it would otherwise
|
||||
// be omitted and the decoder would fail to resolve r.Peer, leaving
|
||||
// the client without a WG tunnel target for this route. Legacy
|
||||
// NetworkMap.Routes shipped the WG public key inline, so the
|
||||
// equivalence path doesn't surface this — but the dependency is
|
||||
// real once a client actually tries to use the route.
|
||||
// Gate by validatedPeersMap so non-validated advertisers stay out
|
||||
// (matches the network-resource router behaviour at the bottom of
|
||||
// this loop, and the legacy invariant that only validated peers
|
||||
// reach a client's view).
|
||||
if r.Peer != "" {
|
||||
if _, ok := validatedPeersMap[r.Peer]; ok {
|
||||
if p := a.GetPeer(r.Peer); p != nil {
|
||||
relevantPeerIDs[r.Peer] = p
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, groupID := range r.PeerGroups {
|
||||
g := a.GetGroup(groupID)
|
||||
if g == nil {
|
||||
continue
|
||||
}
|
||||
for _, pid := range g.Peers {
|
||||
if _, exists := relevantPeerIDs[pid]; exists {
|
||||
continue
|
||||
}
|
||||
if _, ok := validatedPeersMap[pid]; !ok {
|
||||
continue
|
||||
}
|
||||
if p := a.GetPeer(pid); p != nil {
|
||||
relevantPeerIDs[pid] = p
|
||||
}
|
||||
}
|
||||
}
|
||||
relevantRoutes = append(relevantRoutes, r)
|
||||
}
|
||||
|
||||
@@ -353,7 +477,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||
default:
|
||||
sshReqs.needAllowedUserIDs = true
|
||||
}
|
||||
} else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
|
||||
} else if PolicyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
|
||||
sshReqs.needAllowedUserIDs = true
|
||||
}
|
||||
}
|
||||
@@ -486,6 +610,13 @@ func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChe
|
||||
return dest
|
||||
}
|
||||
|
||||
// filterGroupPeers trims each group's Peers slice to only those peers that
|
||||
// also appear in `peers`. Groups whose filtered list is empty are NOT
|
||||
// deleted from the map — they're kept so the components wire encoder can
|
||||
// still resolve seq references from routes/policies/access-control groups
|
||||
// that name them. Calculate() tolerates groups with empty Peers (the inner
|
||||
// loops simply iterate zero times), so retaining them is behaviourally a
|
||||
// no-op for the legacy path that consumes the same NetworkMapComponents.
|
||||
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
|
||||
for groupID, groupInfo := range *groups {
|
||||
filteredPeers := make([]string, 0, len(groupInfo.Peers))
|
||||
@@ -495,9 +626,7 @@ func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredPeers) == 0 {
|
||||
delete(*groups, groupID)
|
||||
} else if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
if len(filteredPeers) != len(groupInfo.Peers) {
|
||||
ng := groupInfo.Copy()
|
||||
ng.Peers = filteredPeers
|
||||
(*groups)[groupID] = ng
|
||||
|
||||
29
management/server/types/account_seq_counter.go
Normal file
29
management/server/types/account_seq_counter.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package types
|
||||
|
||||
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
|
||||
type AccountSeqEntity string
|
||||
|
||||
const (
|
||||
AccountSeqEntityPolicy AccountSeqEntity = "policy"
|
||||
AccountSeqEntityGroup AccountSeqEntity = "group"
|
||||
AccountSeqEntityRoute AccountSeqEntity = "route"
|
||||
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
|
||||
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
|
||||
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
|
||||
AccountSeqEntityNetwork AccountSeqEntity = "network"
|
||||
AccountSeqEntityPostureCheck AccountSeqEntity = "posture_check"
|
||||
)
|
||||
|
||||
// AccountSeqCounter tracks the next per-account integer id for a given component
|
||||
// kind. Reads/writes go through the store inside the same transaction as the
|
||||
// component insert so two concurrent inserts cannot collide on the same id.
|
||||
type AccountSeqCounter struct {
|
||||
AccountID string `gorm:"primaryKey;size:255"`
|
||||
Entity string `gorm:"primaryKey;size:32"`
|
||||
NextID uint32 `gorm:"not null;default:1"`
|
||||
}
|
||||
|
||||
// TableName overrides the GORM-derived table name.
|
||||
func (AccountSeqCounter) TableName() string {
|
||||
return "account_seq_counters"
|
||||
}
|
||||
@@ -666,7 +666,7 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := expandPortsAndRanges(tt.base, tt.rule, tt.peer)
|
||||
result := ExpandPortsAndRanges(tt.base, tt.rule, tt.peer)
|
||||
|
||||
var ports []string
|
||||
for _, fr := range result {
|
||||
|
||||
142
management/server/types/aliases.go
Normal file
142
management/server/types/aliases.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
sharedtypes "github.com/netbirdio/netbird/shared/management/types"
|
||||
)
|
||||
|
||||
// Type aliases for types relocated to shared/management/types so that the
|
||||
// client-side compute path can depend on them
|
||||
|
||||
type DNSSettings = sharedtypes.DNSSettings
|
||||
|
||||
type FirewallRule = sharedtypes.FirewallRule
|
||||
|
||||
type Group = sharedtypes.Group
|
||||
type GroupPeer = sharedtypes.GroupPeer
|
||||
|
||||
type Network = sharedtypes.Network
|
||||
type NetworkMap = sharedtypes.NetworkMap
|
||||
type ForwardingRule = sharedtypes.ForwardingRule
|
||||
|
||||
type Policy = sharedtypes.Policy
|
||||
type PolicyUpdateOperation = sharedtypes.PolicyUpdateOperation
|
||||
|
||||
type PolicyRule = sharedtypes.PolicyRule
|
||||
type PolicyUpdateOperationType = sharedtypes.PolicyUpdateOperationType
|
||||
type PolicyTrafficActionType = sharedtypes.PolicyTrafficActionType
|
||||
type PolicyRuleProtocolType = sharedtypes.PolicyRuleProtocolType
|
||||
type PolicyRuleDirection = sharedtypes.PolicyRuleDirection
|
||||
type RulePortRange = sharedtypes.RulePortRange
|
||||
|
||||
type Resource = sharedtypes.Resource
|
||||
type ResourceType = sharedtypes.ResourceType
|
||||
|
||||
type RouteFirewallRule = sharedtypes.RouteFirewallRule
|
||||
|
||||
type NetworkMapComponents = sharedtypes.NetworkMapComponents
|
||||
type AccountSettingsInfo = sharedtypes.AccountSettingsInfo
|
||||
|
||||
type GroupCompact = sharedtypes.GroupCompact
|
||||
type NetworkMapComponentsCompact = sharedtypes.NetworkMapComponentsCompact
|
||||
|
||||
type LookupMap = sharedtypes.LookupMap
|
||||
type FirewallRuleContext = sharedtypes.FirewallRuleContext
|
||||
|
||||
const (
|
||||
GroupIssuedAPI = sharedtypes.GroupIssuedAPI
|
||||
GroupIssuedJWT = sharedtypes.GroupIssuedJWT
|
||||
GroupIssuedIntegration = sharedtypes.GroupIssuedIntegration
|
||||
GroupAllName = sharedtypes.GroupAllName
|
||||
)
|
||||
|
||||
// Function forwarders preserve types.X(...) call sites that previously
|
||||
// resolved to package-local funcs. Plain forwarders (not var aliases) keep
|
||||
// the symbol immutable and allow the inliner to flatten the call.
|
||||
|
||||
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return sharedtypes.PolicyRuleImpliesLegacySSH(rule)
|
||||
}
|
||||
|
||||
func ExpandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
return sharedtypes.ExpandPortsAndRanges(base, rule, peer)
|
||||
}
|
||||
|
||||
func AppendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc FirewallRuleContext) []*FirewallRule {
|
||||
return sharedtypes.AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, rc)
|
||||
}
|
||||
|
||||
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
|
||||
return sharedtypes.CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
|
||||
func GenerateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule {
|
||||
return sharedtypes.GenerateRouteFirewallRules(ctx, route, rule, groupPeers, direction, includeIPv6)
|
||||
}
|
||||
|
||||
func AllocateIPv6Subnet(r *rand.Rand) net.IPNet {
|
||||
return sharedtypes.AllocateIPv6Subnet(r)
|
||||
}
|
||||
|
||||
func NewNetwork() *Network {
|
||||
return sharedtypes.NewNetwork()
|
||||
}
|
||||
|
||||
func AllocatePeerIP(prefix netip.Prefix, takenIps []netip.Addr) (netip.Addr, error) {
|
||||
return sharedtypes.AllocatePeerIP(prefix, takenIps)
|
||||
}
|
||||
|
||||
func AllocateRandomPeerIP(prefix netip.Prefix) (netip.Addr, error) {
|
||||
return sharedtypes.AllocateRandomPeerIP(prefix)
|
||||
}
|
||||
|
||||
func AllocateRandomPeerIPv6(prefix netip.Prefix) (netip.Addr, error) {
|
||||
return sharedtypes.AllocateRandomPeerIPv6(prefix)
|
||||
}
|
||||
|
||||
func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) {
|
||||
return sharedtypes.ParseRuleString(rule)
|
||||
}
|
||||
|
||||
const (
|
||||
FirewallRuleDirectionIN = sharedtypes.FirewallRuleDirectionIN
|
||||
FirewallRuleDirectionOUT = sharedtypes.FirewallRuleDirectionOUT
|
||||
)
|
||||
|
||||
const (
|
||||
ResourceTypePeer = sharedtypes.ResourceTypePeer
|
||||
ResourceTypeDomain = sharedtypes.ResourceTypeDomain
|
||||
ResourceTypeHost = sharedtypes.ResourceTypeHost
|
||||
ResourceTypeSubnet = sharedtypes.ResourceTypeSubnet
|
||||
)
|
||||
|
||||
const (
|
||||
PolicyTrafficActionAccept = sharedtypes.PolicyTrafficActionAccept
|
||||
PolicyTrafficActionDrop = sharedtypes.PolicyTrafficActionDrop
|
||||
)
|
||||
|
||||
const (
|
||||
PolicyRuleProtocolALL = sharedtypes.PolicyRuleProtocolALL
|
||||
PolicyRuleProtocolTCP = sharedtypes.PolicyRuleProtocolTCP
|
||||
PolicyRuleProtocolUDP = sharedtypes.PolicyRuleProtocolUDP
|
||||
PolicyRuleProtocolICMP = sharedtypes.PolicyRuleProtocolICMP
|
||||
PolicyRuleProtocolNetbirdSSH = sharedtypes.PolicyRuleProtocolNetbirdSSH
|
||||
)
|
||||
|
||||
const (
|
||||
PolicyRuleFlowDirect = sharedtypes.PolicyRuleFlowDirect
|
||||
PolicyRuleFlowBidirect = sharedtypes.PolicyRuleFlowBidirect
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultRuleName = sharedtypes.DefaultRuleName
|
||||
DefaultRuleDescription = sharedtypes.DefaultRuleDescription
|
||||
DefaultPolicyName = sharedtypes.DefaultPolicyName
|
||||
DefaultPolicyDescription = sharedtypes.DefaultPolicyDescription
|
||||
)
|
||||
180
management/server/types/networkmap_wire_benchmark_test.go
Normal file
180
management/server/types/networkmap_wire_benchmark_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// wireBenchScales — trimmed scale set for wire-size measurements. Encoding
|
||||
// and marshalling are linear, so the largest extremes don't add signal.
|
||||
var wireBenchScales = []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
}
|
||||
|
||||
// populateAccountSeqIDs assigns deterministic AccountSeqIDs to every group and
|
||||
// policy in the account so that the component encoder can reference them. The
|
||||
// scalableTestAccount fixture builds entities by struct literal and skips this
|
||||
// step, but production paths populate the IDs via the store layer.
|
||||
func populateAccountSeqIDs(account *types.Account) {
|
||||
var nextGroupSeq uint32 = 1
|
||||
for _, g := range account.Groups {
|
||||
g.AccountSeqID = nextGroupSeq
|
||||
nextGroupSeq++
|
||||
}
|
||||
var nextPolicySeq uint32 = 1
|
||||
for _, p := range account.Policies {
|
||||
p.AccountSeqID = nextPolicySeq
|
||||
nextPolicySeq++
|
||||
}
|
||||
}
|
||||
|
||||
// assignValidWgKeys overwrites every peer's Key with a valid base64-encoded
|
||||
// 32-byte string. The default scalableTestAccount uses unparsable strings
|
||||
// like "key-peer-0", which makes the components encoder emit a nil WgPubKey
|
||||
// and the legacy encoder ship 10-char placeholders — both shrink the wire
|
||||
// size in unrealistic ways. Production peers always have valid 44-char base64
|
||||
// keys, so any benchmark/breakdown that wants honest numbers must call this.
|
||||
func assignValidWgKeys(account *types.Account) {
|
||||
for _, p := range account.Peers {
|
||||
var raw [32]byte
|
||||
_, _ = rand.Read(raw[:])
|
||||
p.Key = base64.StdEncoding.EncodeToString(raw[:])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireEncode reports per-call ns and the marshaled wire
|
||||
// size for both encoding paths. Run with:
|
||||
//
|
||||
// go test -run=^$ -bench=BenchmarkNetworkMapWireEncode -benchmem ./management/server/types/
|
||||
func BenchmarkNetworkMapWireEncode(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
// Pre-encode once so the size metric is identical for every run inside
|
||||
// the same scale; the b.Loop call only re-runs encode + Marshal.
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
envelopeInput := mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
}
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
envelopeBytes, err := goproto.Marshal(envelope)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("legacy/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(legacyBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
resp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
if _, err := goproto.Marshal(resp.NetworkMap); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("components/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ReportMetric(float64(len(envelopeBytes)), "bytes/msg")
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(envelopeInput)
|
||||
if _, err := goproto.Marshal(env); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapWireSize is a fast snapshot of the wire size by scale
|
||||
// without a tight encode loop. Run with -bench to see one ns/op + bytes per
|
||||
// scale (treat the timing as informational; the sample is one Marshal per
|
||||
// scale, not the full b.N loop).
|
||||
func BenchmarkNetworkMapWireSize(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
|
||||
for _, scale := range wireBenchScales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyBytes, err := goproto.Marshal(legacyResp.NetworkMap)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal legacy networkmap: %v", err)
|
||||
}
|
||||
|
||||
env := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
envBytes, err := goproto.Marshal(env)
|
||||
if err != nil {
|
||||
b.Fatalf("marshal envelope: %v", err)
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("size/%s", scale.name), func(b *testing.B) {
|
||||
b.ReportMetric(float64(len(legacyBytes)), "legacy_bytes")
|
||||
b.ReportMetric(float64(len(envBytes)), "components_bytes")
|
||||
ratio := float64(len(envBytes)) / float64(len(legacyBytes))
|
||||
b.ReportMetric(ratio, "components/legacy")
|
||||
for range b.N {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
150
management/server/types/networkmap_wire_breakdown_test.go
Normal file
150
management/server/types/networkmap_wire_breakdown_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkMapWireBreakdown is a one-shot diagnostic: it computes the wire
|
||||
// size attributable to each top-level field of both the legacy NetworkMap and
|
||||
// the components NetworkMapEnvelope at the 5000-peer scale, so the migration
|
||||
// docs can attribute the size reduction to each optimization. Runs only on
|
||||
// demand via -run TestNetworkMapWireBreakdown.
|
||||
func TestNetworkMapWireBreakdown(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("size diagnostic, skipped with -short")
|
||||
}
|
||||
if os.Getenv("NB_RUN_WIRE_BREAKDOWN") != "1" {
|
||||
t.Skip("set NB_RUN_WIRE_BREAKDOWN=1 to run wire breakdown diagnostic")
|
||||
}
|
||||
|
||||
const peerCount, groupCount = 5000, 100
|
||||
account, validatedPeers := scalableTestAccount(peerCount, groupCount)
|
||||
populateAccountSeqIDs(account)
|
||||
assignValidWgKeys(account)
|
||||
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
peerID := "peer-0"
|
||||
peer := account.Peers[peerID]
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
components := account.GetPeerNetworkMapComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
settings := &types.Settings{}
|
||||
|
||||
legacyResp := mgmtgrpc.ToSyncResponse(ctx, nil, nil, nil, peer, nil, nil, networkMap, "netbird.cloud", nil, dnsCache, settings, nil, nil, 0)
|
||||
legacyTotal := mustMarshalSize(t, legacyResp.NetworkMap)
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: components,
|
||||
PeerConfig: legacyResp.NetworkMap.PeerConfig,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
componentsTotal := mustMarshalSize(t, envelope)
|
||||
|
||||
t.Logf("\n=== LEGACY NetworkMap (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes\n", legacyTotal)
|
||||
|
||||
legacyBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMap
|
||||
}{
|
||||
{"RemotePeers", &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers}},
|
||||
{"OfflinePeers", &proto.NetworkMap{OfflinePeers: legacyResp.NetworkMap.OfflinePeers}},
|
||||
{"FirewallRules", &proto.NetworkMap{FirewallRules: legacyResp.NetworkMap.FirewallRules}},
|
||||
{"Routes", &proto.NetworkMap{Routes: legacyResp.NetworkMap.Routes}},
|
||||
{"RoutesFirewallRules", &proto.NetworkMap{RoutesFirewallRules: legacyResp.NetworkMap.RoutesFirewallRules}},
|
||||
{"DNSConfig", &proto.NetworkMap{DNSConfig: legacyResp.NetworkMap.DNSConfig}},
|
||||
{"PeerConfig", &proto.NetworkMap{PeerConfig: legacyResp.NetworkMap.PeerConfig}},
|
||||
{"SshAuth", &proto.NetworkMap{SshAuth: legacyResp.NetworkMap.SshAuth}},
|
||||
}
|
||||
for _, e := range legacyBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, legacyTotal))
|
||||
}
|
||||
|
||||
full := envelope.GetFull()
|
||||
if full == nil {
|
||||
t.Fatalf("expected full network map envelope payload, got nil")
|
||||
}
|
||||
t.Logf("\n=== COMPONENTS NetworkMapEnvelope (%d peers, %d groups) ===", peerCount, groupCount)
|
||||
t.Logf(" Total: %d bytes (%.1f%% of legacy)\n", componentsTotal, pct(componentsTotal, legacyTotal))
|
||||
|
||||
componentsBreakdown := []struct {
|
||||
name string
|
||||
nm *proto.NetworkMapComponentsFull
|
||||
}{
|
||||
{"Peers", &proto.NetworkMapComponentsFull{Peers: full.Peers}},
|
||||
{"Policies", &proto.NetworkMapComponentsFull{Policies: full.Policies}},
|
||||
{"Groups", &proto.NetworkMapComponentsFull{Groups: full.Groups}},
|
||||
{"Routes (raw)", &proto.NetworkMapComponentsFull{Routes: full.Routes}},
|
||||
{"NameServerGroups", &proto.NetworkMapComponentsFull{NameserverGroups: full.NameserverGroups}},
|
||||
{"AllDNSRecords", &proto.NetworkMapComponentsFull{AllDnsRecords: full.AllDnsRecords}},
|
||||
{"AccountZones", &proto.NetworkMapComponentsFull{AccountZones: full.AccountZones}},
|
||||
{"NetworkResources", &proto.NetworkMapComponentsFull{NetworkResources: full.NetworkResources}},
|
||||
{"RoutersMap", &proto.NetworkMapComponentsFull{RoutersMap: full.RoutersMap}},
|
||||
{"ResourcePoliciesMap", &proto.NetworkMapComponentsFull{ResourcePoliciesMap: full.ResourcePoliciesMap}},
|
||||
{"GroupIDToUserIDs", &proto.NetworkMapComponentsFull{GroupIdToUserIds: full.GroupIdToUserIds}},
|
||||
{"AllowedUserIDs", &proto.NetworkMapComponentsFull{AllowedUserIds: full.AllowedUserIds}},
|
||||
{"PostureFailedPeers", &proto.NetworkMapComponentsFull{PostureFailedPeers: full.PostureFailedPeers}},
|
||||
{"DNSSettings", &proto.NetworkMapComponentsFull{DnsSettings: full.DnsSettings}},
|
||||
{"PeerConfig", &proto.NetworkMapComponentsFull{PeerConfig: full.PeerConfig}},
|
||||
{"AgentVersions", &proto.NetworkMapComponentsFull{AgentVersions: full.AgentVersions}},
|
||||
}
|
||||
for _, e := range componentsBreakdown {
|
||||
size := mustMarshalSize(t, e.nm)
|
||||
t.Logf(" %-22s %8d bytes %5.1f%%", e.name, size, pct(size, componentsTotal))
|
||||
}
|
||||
|
||||
t.Logf("\n=== Per-PeerCompact average ===")
|
||||
if len(full.Peers) > 0 {
|
||||
t.Logf(" PeerCompact avg: %d bytes/peer", mustMarshalSize(t, &proto.NetworkMapComponentsFull{Peers: full.Peers})/len(full.Peers))
|
||||
}
|
||||
if len(legacyResp.NetworkMap.RemotePeers) > 0 {
|
||||
t.Logf(" RemotePeer avg: %d bytes/peer",
|
||||
mustMarshalSize(t, &proto.NetworkMap{RemotePeers: legacyResp.NetworkMap.RemotePeers})/len(legacyResp.NetworkMap.RemotePeers))
|
||||
}
|
||||
|
||||
t.Logf("\n=== FirewallRule expansion footprint ===")
|
||||
t.Logf(" legacy FirewallRules count: %d", len(legacyResp.NetworkMap.FirewallRules))
|
||||
t.Logf(" components Policies count: %d", len(full.Policies))
|
||||
t.Logf(" components Groups count: %d", len(full.Groups))
|
||||
|
||||
totalGroupPeerIdxs := 0
|
||||
for _, g := range full.Groups {
|
||||
totalGroupPeerIdxs += len(g.PeerIndexes)
|
||||
}
|
||||
t.Logf(" components peer-index refs across all groups: %d", totalGroupPeerIdxs)
|
||||
}
|
||||
|
||||
func mustMarshalSize(t *testing.T, m goproto.Message) int {
|
||||
b, err := goproto.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
return len(b)
|
||||
}
|
||||
|
||||
func pct(part, total int) float64 {
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
return 100 * float64(part) / float64(total)
|
||||
}
|
||||
|
||||
// Stops fmt being unused if the breakdown loop above is later commented out.
|
||||
var _ = fmt.Sprintf
|
||||
25
management/server/types/peer_networkmap_result.go
Normal file
25
management/server/types/peer_networkmap_result.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
// PeerNetworkMapResult is what the network_map controller produces for a
|
||||
// single peer. Exactly one of NetworkMap or Components is populated depending
|
||||
// on the peer's capability:
|
||||
//
|
||||
// - Components-capable peers (PeerCapabilityComponentNetworkMap) get
|
||||
// Components: the raw types.NetworkMapComponents the client decodes and
|
||||
// runs Calculate() on locally. NetworkMap stays nil — the server skips
|
||||
// the expansion entirely.
|
||||
// - Legacy peers (or any peer when the kill switch is set) get NetworkMap:
|
||||
// the fully-expanded view the legacy gRPC path consumes.
|
||||
//
|
||||
// The gRPC layer (ToSyncResponseForPeer) dispatches by which field is
|
||||
// non-nil; callers must not rely on both being set.
|
||||
type PeerNetworkMapResult struct {
|
||||
NetworkMap *NetworkMap
|
||||
Components *NetworkMapComponents
|
||||
}
|
||||
|
||||
// IsComponents reports whether the result carries the components shape.
|
||||
// Use this in preference to direct nil checks on the fields.
|
||||
func (r PeerNetworkMapResult) IsComponents() bool {
|
||||
return r.Components != nil
|
||||
}
|
||||
104
management/server/types/peer_networkmap_result_test.go
Normal file
104
management/server/types/peer_networkmap_result_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// helper: marks the given peer as components-capable.
|
||||
func markCapable(p *nbpeer.Peer) {
|
||||
p.Meta.Capabilities = append(p.Meta.Capabilities, nbpeer.PeerCapabilityComponentNetworkMap)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_CapablePeerGetsComponents(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false, // componentsDisabled
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
require.True(t, result.IsComponents(), "capable peer must get the components shape")
|
||||
assert.Nil(t, result.NetworkMap)
|
||||
require.NotNil(t, result.Components)
|
||||
assert.Equal(t, "peer-0", result.Components.PeerID)
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_LegacyPeerGetsNetworkMap(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
// peer-0 left without the component capability
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
false,
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents())
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap, "legacy peer must get a NetworkMap")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMapResult_KillSwitchOverridesCapability(t *testing.T) {
|
||||
// Capable peer + componentsDisabled=true → falls back to legacy.
|
||||
account, validatedPeers := scalableTestAccount(10, 2)
|
||||
markCapable(account.Peers["peer-0"])
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
result := account.GetPeerNetworkMapResult(
|
||||
context.Background(),
|
||||
"peer-0",
|
||||
true, // componentsDisabled = true (kill switch)
|
||||
nbdns.CustomZone{},
|
||||
nil,
|
||||
validatedPeers,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
assert.False(t, result.IsComponents(), "kill switch must force legacy NetworkMap path")
|
||||
assert.Nil(t, result.Components)
|
||||
require.NotNil(t, result.NetworkMap)
|
||||
}
|
||||
|
||||
func TestPeerNetworkMapResult_IsComponents(t *testing.T) {
|
||||
assert.True(t, types.PeerNetworkMapResult{Components: &types.NetworkMapComponents{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{NetworkMap: &types.NetworkMap{}}.IsComponents())
|
||||
assert.False(t, types.PeerNetworkMapResult{}.IsComponents())
|
||||
}
|
||||
@@ -95,6 +95,9 @@ type Route struct {
|
||||
ID ID `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"`
|
||||
// Network and Domains are mutually exclusive
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
@@ -128,6 +131,7 @@ func (r *Route) Copy() *Route {
|
||||
route := &Route{
|
||||
ID: r.ID,
|
||||
AccountID: r.AccountID,
|
||||
AccountSeqID: r.AccountSeqID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
|
||||
@@ -316,33 +316,87 @@ func TestClient_Sync(t *testing.T) {
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
if resp.GetPeerConfig() == nil {
|
||||
if resp.GetPeerConfig() == nil && resp.GetNetworkMap().GetPeerConfig() == nil {
|
||||
t.Error("expecting non nil PeerConfig got nil")
|
||||
}
|
||||
if resp.GetNetbirdConfig() == nil {
|
||||
t.Error("expecting non nil NetbirdConfig got nil")
|
||||
}
|
||||
// we test network map peers from 0.29.3 and dev builds
|
||||
// Top-level RemotePeers is deprecated and must stay empty for
|
||||
// v0.29.3+ (and dev) clients — the field rides inside NetworkMap
|
||||
// (legacy) or the NetworkMapEnvelope (components) instead.
|
||||
if len(resp.GetRemotePeers()) != 0 {
|
||||
t.Error("expecting top-level RemotePeers to be empty for v0.29.3+ clients")
|
||||
}
|
||||
networkMap := resp.GetNetworkMap()
|
||||
if len(networkMap.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(networkMap.GetRemotePeers()))
|
||||
// Component-capable clients receive a NetworkMapEnvelope; the
|
||||
// remote-peers list is encoded inside it. Decode it and check the
|
||||
// envelope's peers slice. Legacy peers populate NetworkMap.RemotePeers;
|
||||
// both shapes must surface exactly one remote peer.
|
||||
remotePeerKeys := remotePeerKeysFromSync(resp, testKey.PublicKey().String())
|
||||
if len(remotePeerKeys) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(remotePeerKeys))
|
||||
return
|
||||
}
|
||||
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
if resp.GetNetworkMap() != nil && resp.GetNetworkMap().GetRemotePeersIsEmpty() {
|
||||
t.Error("expecting RemotePeers property to be false, got true")
|
||||
}
|
||||
if networkMap.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), networkMap.GetRemotePeers()[0].GetWgPubKey())
|
||||
if remotePeerKeys[0] != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), remotePeerKeys[0])
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for test to finish")
|
||||
}
|
||||
}
|
||||
|
||||
// remotePeerKeysFromSync extracts the remote-peer WG keys from either the
|
||||
// legacy NetworkMap.RemotePeers list or the components NetworkMapEnvelope's
|
||||
// inner peers slice (filtering out the local receiving peer identified by
|
||||
// localKey, since the envelope's peers list is index-addressed and includes
|
||||
// the local peer alongside remotes).
|
||||
func remotePeerKeysFromSync(resp *mgmtProto.SyncResponse, localKey string) []string {
|
||||
if rp := resp.GetRemotePeers(); len(rp) > 0 {
|
||||
out := make([]string, 0, len(rp))
|
||||
for _, p := range rp {
|
||||
out = append(out, p.GetWgPubKey())
|
||||
}
|
||||
return out
|
||||
}
|
||||
if rp := resp.GetNetworkMap().GetRemotePeers(); len(rp) > 0 {
|
||||
out := make([]string, 0, len(rp))
|
||||
for _, p := range rp {
|
||||
out = append(out, p.GetWgPubKey())
|
||||
}
|
||||
return out
|
||||
}
|
||||
env := resp.GetNetworkMapEnvelope().GetFull()
|
||||
if env == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(env.GetPeers()))
|
||||
for _, p := range env.GetPeers() {
|
||||
key := wgKeyFromBytes(p.GetWgPubKey())
|
||||
if key == "" || key == localKey {
|
||||
continue
|
||||
}
|
||||
out = append(out, key)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// wgKeyFromBytes mirrors the client-side decoder: the envelope ships raw 32
|
||||
// bytes; reconstruct the standard base64 key the test compares against.
|
||||
func wgKeyFromBytes(raw []byte) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var k wgtypes.Key
|
||||
if len(raw) != len(k) {
|
||||
return ""
|
||||
}
|
||||
copy(k[:], raw)
|
||||
return k.String()
|
||||
}
|
||||
|
||||
func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
|
||||
defer s.GracefulStop()
|
||||
|
||||
@@ -1005,6 +1005,10 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
func peerCapabilities(info system.Info) []proto.PeerCapability {
|
||||
caps := []proto.PeerCapability{
|
||||
proto.PeerCapability_PeerCapabilitySourcePrefixes,
|
||||
// PeerCapabilityComponentNetworkMap signals that this client can
|
||||
// decode the components-format SyncResponse.NetworkMapEnvelope and
|
||||
// run Calculate() locally.
|
||||
proto.PeerCapability_PeerCapabilityComponentNetworkMap,
|
||||
}
|
||||
if !info.DisableIPv6 {
|
||||
caps = append(caps, proto.PeerCapability_PeerCapabilityIPv6Overlay)
|
||||
|
||||
610
shared/management/networkmap/decode.go
Normal file
610
shared/management/networkmap/decode.go
Normal file
@@ -0,0 +1,610 @@
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/types"
|
||||
)
|
||||
|
||||
// DecodeEnvelope converts a NetworkMapEnvelope into a NetworkMapComponents
|
||||
// the client can run Calculate() over. Every ID-reference on the wire is a
|
||||
// uint32 (peer index or account_seq_id) — no xid strings travel. The decoder
|
||||
// synthesises consistent string IDs from the uint32s so the reconstructed
|
||||
// components struct round-trips through Calculate exactly the way the
|
||||
// server-side typed components would.
|
||||
//
|
||||
// ID scheme on the client side:
|
||||
//
|
||||
// Peers base64(wg_pub_key) // stable across snapshots
|
||||
// Groups "g_<account_seq_id>"
|
||||
// Policies "pol_<account_seq_id>" // 1 rule per policy
|
||||
// Routes "r_<account_seq_id>"
|
||||
// Network resources "nres_<account_seq_id>"
|
||||
// Posture checks "pc_<account_seq_id>"
|
||||
// Networks "net_<account_seq_id>"
|
||||
// Nameserver groups "nsg_<account_seq_id>"
|
||||
func DecodeEnvelope(env *proto.NetworkMapEnvelope) (*types.NetworkMapComponents, error) {
|
||||
if env == nil {
|
||||
return nil, fmt.Errorf("nil envelope")
|
||||
}
|
||||
full := env.GetFull()
|
||||
if full == nil {
|
||||
return nil, fmt.Errorf("envelope has no Full payload")
|
||||
}
|
||||
|
||||
c := &types.NetworkMapComponents{
|
||||
PeerID: "", // engine fills its own peer id from PeerConfig
|
||||
Network: decodeAccountNetwork(full.Network),
|
||||
AccountSettings: decodeAccountSettings(full.AccountSettings),
|
||||
CustomZoneDomain: full.CustomZoneDomain,
|
||||
Peers: make(map[string]*nbpeer.Peer, len(full.Peers)),
|
||||
Groups: make(map[string]*types.Group, len(full.Groups)),
|
||||
Policies: make([]*types.Policy, 0, len(full.Policies)),
|
||||
Routes: make([]*nbroute.Route, 0, len(full.Routes)),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0, len(full.NameserverGroups)),
|
||||
AllDNSRecords: decodeSimpleRecords(full.AllDnsRecords),
|
||||
AccountZones: decodeCustomZones(full.AccountZones),
|
||||
ResourcePoliciesMap: make(map[string][]*types.Policy),
|
||||
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||
NetworkResources: make([]*resourceTypes.NetworkResource, 0, len(full.NetworkResources)),
|
||||
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||
AllowedUserIDs: stringSliceToSet(full.AllowedUserIds),
|
||||
PostureFailedPeers: make(map[string]map[string]struct{}, len(full.PostureFailedPeers)),
|
||||
GroupIDToUserIDs: make(map[string][]string, len(full.GroupIdToUserIds)),
|
||||
}
|
||||
|
||||
if full.DnsSettings != nil {
|
||||
c.DNSSettings = &types.DNSSettings{
|
||||
DisabledManagementGroups: groupIDsFromSeqs(full.DnsSettings.DisabledManagementGroupIds),
|
||||
}
|
||||
} else {
|
||||
c.DNSSettings = &types.DNSSettings{}
|
||||
}
|
||||
|
||||
// Phase 1: peers. The envelope's peers slice is index-addressed on the
|
||||
// wire; we re-key by the peer's WireGuard public key (base64) so the
|
||||
// in-memory components struct uses a stable identifier across
|
||||
// snapshots. peerIDByIndex lets downstream phases resolve wire indexes
|
||||
// back to that key. A peer with a missing or malformed wg_pub_key is
|
||||
// skipped (and its index keeps "" so any cross-reference falls into the
|
||||
// same missing-peer branch downstream) — matches legacy behaviour, which
|
||||
// degrades gracefully rather than aborting the whole sync on a single
|
||||
// bad row.
|
||||
peerIDByIndex := make([]string, len(full.Peers))
|
||||
for idx, pc := range full.Peers {
|
||||
if pc == nil {
|
||||
log.Warnf("envelope: peers[%d] is nil, skipping", idx)
|
||||
continue
|
||||
}
|
||||
if len(pc.WgPubKey) != 32 {
|
||||
log.Warnf("envelope: peers[%d] wg_pub_key length %d (want 32), skipping", idx, len(pc.WgPubKey))
|
||||
continue
|
||||
}
|
||||
peerID := base64.StdEncoding.EncodeToString(pc.WgPubKey)
|
||||
peer := decodePeerCompact(pc, peerID, full.AgentVersions)
|
||||
c.Peers[peerID] = peer
|
||||
peerIDByIndex[idx] = peerID
|
||||
}
|
||||
|
||||
// Phase 2: groups. AccountSeqID becomes both the synthesized string ID
|
||||
// and the GroupCompact.id wire value.
|
||||
for i, gc := range full.Groups {
|
||||
if gc == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: groups[%d] is nil", i)
|
||||
}
|
||||
groupID := synthGroupID(gc.Id)
|
||||
peerIDs := make([]string, 0, len(gc.PeerIndexes))
|
||||
for _, idx := range gc.PeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
peerIDs = append(peerIDs, peerIDByIndex[idx])
|
||||
}
|
||||
}
|
||||
c.Groups[groupID] = &types.Group{
|
||||
ID: groupID,
|
||||
AccountSeqID: gc.Id,
|
||||
Name: gc.Name,
|
||||
Peers: peerIDs,
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: policies (PolicyCompact = one rule per entry; current data
|
||||
// model is 1 rule per policy). Policy.ID is synthesized from the
|
||||
// per-account seq id; proto.FirewallRule.PolicyID downstream carries
|
||||
// the same synth string (no xid on the wire).
|
||||
for i, pc := range full.Policies {
|
||||
if pc == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: policies[%d] is nil", i)
|
||||
}
|
||||
policyID := synthPolicyID(pc.Id)
|
||||
c.Policies = append(c.Policies, decodePolicyCompact(pc, policyID, peerIDByIndex))
|
||||
}
|
||||
|
||||
// Phase 4: routes.
|
||||
for i, rr := range full.Routes {
|
||||
if rr == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: routes[%d] is nil", i)
|
||||
}
|
||||
c.Routes = append(c.Routes, decodeRouteRaw(rr, peerIDByIndex))
|
||||
}
|
||||
|
||||
// Phase 5: NSGs.
|
||||
for i, nsg := range full.NameserverGroups {
|
||||
if nsg == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: nameserver_groups[%d] is nil", i)
|
||||
}
|
||||
c.NameServerGroups = append(c.NameServerGroups, decodeNameServerGroupRaw(nsg))
|
||||
}
|
||||
|
||||
// Phase 6: network resources.
|
||||
for i, nr := range full.NetworkResources {
|
||||
if nr == nil {
|
||||
return nil, fmt.Errorf("invalid envelope: network_resources[%d] is nil", i)
|
||||
}
|
||||
c.NetworkResources = append(c.NetworkResources, decodeNetworkResource(nr))
|
||||
}
|
||||
|
||||
// Phase 7: routers_map (outer key = network seq id, inner key = peer-id
|
||||
// reconstructed from peer_index). Synthesized network id is "net_<seq>".
|
||||
for networkSeq, list := range full.RoutersMap {
|
||||
networkID := synthNetworkID(networkSeq)
|
||||
inner := make(map[string]*routerTypes.NetworkRouter, len(list.Entries))
|
||||
for _, entry := range list.Entries {
|
||||
if !entry.PeerIndexSet {
|
||||
continue
|
||||
}
|
||||
if int(entry.PeerIndex) >= len(peerIDByIndex) {
|
||||
continue
|
||||
}
|
||||
peerID := peerIDByIndex[entry.PeerIndex]
|
||||
inner[peerID] = &routerTypes.NetworkRouter{
|
||||
ID: "",
|
||||
NetworkID: networkID,
|
||||
AccountSeqID: entry.Id,
|
||||
Peer: peerID,
|
||||
PeerGroups: groupIDsFromSeqs(entry.PeerGroupIds),
|
||||
Masquerade: entry.Masquerade,
|
||||
Metric: int(entry.Metric),
|
||||
Enabled: entry.Enabled,
|
||||
}
|
||||
}
|
||||
if len(inner) > 0 {
|
||||
c.RoutersMap[networkID] = inner
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 8: resource_policies_map (resource seq id → list of *types.Policy
|
||||
// pointers from the decoded policies slice). Resource ID is synthesized
|
||||
// the same way as in decodeNetworkResource.
|
||||
for resourceSeq, idxs := range full.ResourcePoliciesMap {
|
||||
if len(idxs.Indexes) == 0 {
|
||||
continue
|
||||
}
|
||||
resourceID := synthNetworkResourceID(resourceSeq)
|
||||
policies := make([]*types.Policy, 0, len(idxs.Indexes))
|
||||
for _, i := range idxs.Indexes {
|
||||
if int(i) < len(c.Policies) {
|
||||
policies = append(policies, c.Policies[i])
|
||||
}
|
||||
}
|
||||
if len(policies) > 0 {
|
||||
c.ResourcePoliciesMap[resourceID] = policies
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 9: group_id_to_user_ids — wire keys are seq ids, synth to strings.
|
||||
for groupSeq, list := range full.GroupIdToUserIds {
|
||||
c.GroupIDToUserIDs[synthGroupID(groupSeq)] = append([]string(nil), list.UserIds...)
|
||||
}
|
||||
|
||||
// Phase 10: posture_failed_peers — wire keys are posture-check seq ids,
|
||||
// values are peer indexes that need to be turned into peer ids. PolicyRule
|
||||
// SourcePostureChecks (also synth ids) reference the same key space.
|
||||
for checkSeq, set := range full.PostureFailedPeers {
|
||||
checkID := synthPostureCheckID(checkSeq)
|
||||
failed := make(map[string]struct{}, len(set.PeerIndexes))
|
||||
for _, idx := range set.PeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
failed[peerIDByIndex[idx]] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
c.PostureFailedPeers[checkID] = failed
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 11: router_peer_indexes — peers that act as routers. They're
|
||||
// already in c.Peers (router peers are appended to the global peers
|
||||
// list by the encoder); RouterPeers is the subset.
|
||||
for _, idx := range full.RouterPeerIndexes {
|
||||
if int(idx) < len(peerIDByIndex) {
|
||||
peerID := peerIDByIndex[idx]
|
||||
c.RouterPeers[peerID] = c.Peers[peerID]
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func decodeAccountNetwork(an *proto.AccountNetwork) *types.Network {
|
||||
if an == nil {
|
||||
return nil
|
||||
}
|
||||
n := &types.Network{
|
||||
Identifier: an.Identifier,
|
||||
Dns: an.Dns,
|
||||
Serial: an.Serial,
|
||||
}
|
||||
if an.NetCidr != "" {
|
||||
if _, ipnet, err := net.ParseCIDR(an.NetCidr); err == nil && ipnet != nil {
|
||||
n.Net = *ipnet
|
||||
}
|
||||
}
|
||||
if an.NetV6Cidr != "" {
|
||||
if _, ipnet, err := net.ParseCIDR(an.NetV6Cidr); err == nil && ipnet != nil {
|
||||
n.NetV6 = *ipnet
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func decodeAccountSettings(as *proto.AccountSettingsCompact) *types.AccountSettingsInfo {
|
||||
if as == nil {
|
||||
return &types.AccountSettingsInfo{}
|
||||
}
|
||||
return &types.AccountSettingsInfo{
|
||||
PeerLoginExpirationEnabled: as.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(as.PeerLoginExpirationNs),
|
||||
}
|
||||
}
|
||||
|
||||
func decodePeerCompact(pc *proto.PeerCompact, peerID string, agentVersions []string) *nbpeer.Peer {
|
||||
var caps []int32
|
||||
if pc.SupportsSourcePrefixes {
|
||||
caps = append(caps, nbpeer.PeerCapabilitySourcePrefixes)
|
||||
}
|
||||
if pc.SupportsIpv6 {
|
||||
caps = append(caps, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
Key: peerID,
|
||||
SSHKey: string(pc.SshPubKey),
|
||||
SSHEnabled: pc.SshEnabled,
|
||||
DNSLabel: pc.DnsLabel,
|
||||
LoginExpirationEnabled: pc.LoginExpirationEnabled,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: lookupAgentVersion(agentVersions, pc.AgentVersionIdx),
|
||||
Capabilities: caps,
|
||||
Flags: nbpeer.Flags{
|
||||
ServerSSHAllowed: pc.ServerSshAllowed,
|
||||
},
|
||||
},
|
||||
}
|
||||
if pc.AddedWithSsoLogin {
|
||||
// Set a non-empty UserID so (*Peer).AddedWithSSOLogin() returns true.
|
||||
// The original UserID isn't on the wire; the value is intentionally
|
||||
// visibly synthetic so any future consumer that mistakes UserID for a
|
||||
// real account user xid won't silently match (or worse, write the
|
||||
// sentinel into a downstream record).
|
||||
peer.UserID = "<env-sso>"
|
||||
}
|
||||
if pc.LastLoginUnixNano != 0 {
|
||||
t := time.Unix(0, pc.LastLoginUnixNano)
|
||||
peer.LastLogin = &t
|
||||
}
|
||||
switch len(pc.Ip) {
|
||||
case 4:
|
||||
peer.IP = netip.AddrFrom4([4]byte{pc.Ip[0], pc.Ip[1], pc.Ip[2], pc.Ip[3]})
|
||||
case 16:
|
||||
var a [16]byte
|
||||
copy(a[:], pc.Ip)
|
||||
peer.IP = netip.AddrFrom16(a)
|
||||
}
|
||||
if len(pc.Ipv6) == 16 {
|
||||
var a [16]byte
|
||||
copy(a[:], pc.Ipv6)
|
||||
peer.IPv6 = netip.AddrFrom16(a)
|
||||
}
|
||||
return peer
|
||||
}
|
||||
|
||||
func decodePolicyCompact(pc *proto.PolicyCompact, policyID string, peerIDByIndex []string) *types.Policy {
|
||||
rule := &types.PolicyRule{
|
||||
ID: policyID, // 1 rule per policy → reuse synthesized id
|
||||
PolicyID: policyID,
|
||||
Enabled: true,
|
||||
Action: actionFromProto(pc.Action),
|
||||
Protocol: protocolFromProto(pc.Protocol),
|
||||
Bidirectional: pc.Bidirectional,
|
||||
Ports: uint32SliceToStrings(pc.Ports),
|
||||
PortRanges: portRangesFromProto(pc.PortRanges),
|
||||
Sources: groupIDsFromSeqs(pc.SourceGroupIds),
|
||||
Destinations: groupIDsFromSeqs(pc.DestinationGroupIds),
|
||||
AuthorizedUser: pc.AuthorizedUser,
|
||||
AuthorizedGroups: authorizedGroupsFromProto(pc.AuthorizedGroups),
|
||||
SourceResource: resourceFromProto(pc.SourceResource, peerIDByIndex),
|
||||
DestinationResource: resourceFromProto(pc.DestinationResource, peerIDByIndex),
|
||||
}
|
||||
return &types.Policy{
|
||||
ID: policyID,
|
||||
AccountSeqID: pc.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{rule},
|
||||
SourcePostureChecks: postureCheckIDsFromSeqs(pc.SourcePostureCheckSeqIds),
|
||||
}
|
||||
}
|
||||
|
||||
// resourceFromProto rebuilds types.Resource. For peer-typed resources the
|
||||
// peer reference is reconstructed from the envelope's peer index — wire
|
||||
// format ships no xid for peers, so we use the synthesized peer id.
|
||||
func resourceFromProto(r *proto.ResourceCompact, peerIDByIndex []string) types.Resource {
|
||||
if r == nil {
|
||||
return types.Resource{}
|
||||
}
|
||||
out := types.Resource{Type: types.ResourceType(r.Type)}
|
||||
if r.PeerIndexSet && int(r.PeerIndex) < len(peerIDByIndex) {
|
||||
out.ID = peerIDByIndex[r.PeerIndex]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// postureCheckIDsFromSeqs synths posture-check ids from per-account seq ids.
|
||||
// Mirrors groupIDsFromSeqs.
|
||||
func postureCheckIDsFromSeqs(seqs []uint32) []string {
|
||||
if len(seqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(seqs))
|
||||
for i, s := range seqs {
|
||||
out[i] = synthPostureCheckID(s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// authorizedGroupsFromProto inverts encodeAuthorizedGroups: the wire form
|
||||
// keys by group account_seq_id, the typed PolicyRule field keys by group
|
||||
// xid string. We rebuild using the same synthetic scheme the rest of the
|
||||
// decoder uses ("g<seq>").
|
||||
func authorizedGroupsFromProto(m map[uint32]*proto.UserNameList) map[string][]string {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string][]string, len(m))
|
||||
for seq, list := range m {
|
||||
if list == nil {
|
||||
continue
|
||||
}
|
||||
out[synthGroupID(seq)] = append([]string(nil), list.Names...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeRouteRaw(rr *proto.RouteRaw, peerIDByIndex []string) *nbroute.Route {
|
||||
r := &nbroute.Route{
|
||||
ID: nbroute.ID(synthRouteID(rr.Id)),
|
||||
AccountSeqID: rr.Id,
|
||||
NetID: nbroute.NetID(rr.NetId),
|
||||
Description: rr.Description,
|
||||
Domains: domainsFromPunycode(rr.Domains),
|
||||
KeepRoute: rr.KeepRoute,
|
||||
NetworkType: nbroute.NetworkType(rr.NetworkType),
|
||||
Masquerade: rr.Masquerade,
|
||||
Metric: int(rr.Metric),
|
||||
Enabled: rr.Enabled,
|
||||
Groups: groupIDsFromSeqs(rr.GroupIds),
|
||||
AccessControlGroups: groupIDsFromSeqs(rr.AccessControlGroupIds),
|
||||
PeerGroups: groupIDsFromSeqs(rr.PeerGroupIds),
|
||||
SkipAutoApply: rr.SkipAutoApply,
|
||||
}
|
||||
if rr.NetworkCidr != "" {
|
||||
if p, err := netip.ParsePrefix(rr.NetworkCidr); err == nil {
|
||||
r.Network = p
|
||||
}
|
||||
}
|
||||
if rr.PeerIndexSet && int(rr.PeerIndex) < len(peerIDByIndex) {
|
||||
r.Peer = peerIDByIndex[rr.PeerIndex]
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func decodeNameServerGroupRaw(nsg *proto.NameServerGroupRaw) *nbdns.NameServerGroup {
|
||||
out := &nbdns.NameServerGroup{
|
||||
ID: synthNameServerGroupID(nsg.Id),
|
||||
AccountSeqID: nsg.Id,
|
||||
Name: nsg.Name,
|
||||
Description: nsg.Description,
|
||||
Groups: groupIDsFromSeqs(nsg.GroupIds),
|
||||
Primary: nsg.Primary,
|
||||
Domains: nsg.Domains,
|
||||
Enabled: nsg.Enabled,
|
||||
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
|
||||
NameServers: make([]nbdns.NameServer, 0, len(nsg.Nameservers)),
|
||||
}
|
||||
for _, ns := range nsg.Nameservers {
|
||||
if addr, err := netip.ParseAddr(ns.IP); err == nil {
|
||||
out.NameServers = append(out.NameServers, nbdns.NameServer{
|
||||
IP: addr,
|
||||
NSType: nbdns.NameServerType(ns.NSType),
|
||||
Port: int(ns.Port),
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeNetworkResource(nr *proto.NetworkResourceRaw) *resourceTypes.NetworkResource {
|
||||
out := &resourceTypes.NetworkResource{
|
||||
ID: synthNetworkResourceID(nr.Id),
|
||||
AccountSeqID: nr.Id,
|
||||
NetworkID: synthNetworkID(nr.NetworkSeq),
|
||||
Name: nr.Name,
|
||||
Description: nr.Description,
|
||||
Type: resourceTypes.NetworkResourceType(nr.Type),
|
||||
Address: nr.Address,
|
||||
Domain: nr.DomainValue,
|
||||
Enabled: nr.Enabled,
|
||||
}
|
||||
if nr.PrefixCidr != "" {
|
||||
if p, err := netip.ParsePrefix(nr.PrefixCidr); err == nil {
|
||||
out.Prefix = p
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeSimpleRecords(records []*proto.SimpleRecord) []nbdns.SimpleRecord {
|
||||
out := make([]nbdns.SimpleRecord, 0, len(records))
|
||||
for _, r := range records {
|
||||
out = append(out, nbdns.SimpleRecord{
|
||||
Name: r.Name,
|
||||
Type: int(r.Type),
|
||||
Class: r.Class,
|
||||
TTL: int(r.TTL),
|
||||
RData: r.RData,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func decodeCustomZones(zones []*proto.CustomZone) []nbdns.CustomZone {
|
||||
out := make([]nbdns.CustomZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
out = append(out, nbdns.CustomZone{
|
||||
Domain: z.Domain,
|
||||
Records: decodeSimpleRecords(z.Records),
|
||||
SearchDomainDisabled: z.SearchDomainDisabled,
|
||||
NonAuthoritative: z.NonAuthoritative,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Synthetic ID generators — deterministic given the same wire input.
|
||||
// Underscore-separated ("p_<n>", "pol_<n>", ...) so they're visually
|
||||
// distinct in operator logs. fmt.Sprintf would dominate the decode hot path
|
||||
// on large accounts (a 10k-peer envelope produces ~50k synth calls); the
|
||||
// strconv.AppendUint builder keeps it allocation-light.
|
||||
func synthID(prefix string, n uint32) string {
|
||||
buf := make([]byte, 0, len(prefix)+10)
|
||||
buf = append(buf, prefix...)
|
||||
buf = strconv.AppendUint(buf, uint64(n), 10)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func synthGroupID(seq uint32) string { return synthID("g_", seq) }
|
||||
func synthPolicyID(seq uint32) string { return synthID("pol_", seq) }
|
||||
func synthRouteID(seq uint32) string { return synthID("r_", seq) }
|
||||
func synthNetworkResourceID(seq uint32) string { return synthID("nres_", seq) }
|
||||
func synthPostureCheckID(seq uint32) string { return synthID("pc_", seq) }
|
||||
func synthNetworkID(seq uint32) string { return synthID("net_", seq) }
|
||||
func synthNameServerGroupID(seq uint32) string { return synthID("nsg_", seq) }
|
||||
|
||||
func groupIDsFromSeqs(seqs []uint32) []string {
|
||||
if len(seqs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(seqs))
|
||||
for i, s := range seqs {
|
||||
out[i] = synthGroupID(s)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func uint32SliceToStrings(ports []uint32) []string {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(ports))
|
||||
for i, p := range ports {
|
||||
out[i] = strconv.FormatUint(uint64(p), 10)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func portRangesFromProto(ranges []*proto.PortInfo_Range) []types.RulePortRange {
|
||||
if len(ranges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]types.RulePortRange, 0, len(ranges))
|
||||
for _, r := range ranges {
|
||||
if r == nil || r.Start > 65535 || r.End > 65535 {
|
||||
continue
|
||||
}
|
||||
out = append(out, types.RulePortRange{
|
||||
Start: uint16(r.Start),
|
||||
End: uint16(r.End),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func actionFromProto(a proto.RuleAction) types.PolicyTrafficActionType {
|
||||
if a == proto.RuleAction_DROP {
|
||||
return types.PolicyTrafficActionDrop
|
||||
}
|
||||
return types.PolicyTrafficActionAccept
|
||||
}
|
||||
|
||||
func protocolFromProto(p proto.RuleProtocol) types.PolicyRuleProtocolType {
|
||||
switch p {
|
||||
case proto.RuleProtocol_TCP:
|
||||
return types.PolicyRuleProtocolTCP
|
||||
case proto.RuleProtocol_UDP:
|
||||
return types.PolicyRuleProtocolUDP
|
||||
case proto.RuleProtocol_ICMP:
|
||||
return types.PolicyRuleProtocolICMP
|
||||
case proto.RuleProtocol_ALL:
|
||||
return types.PolicyRuleProtocolALL
|
||||
case proto.RuleProtocol_NETBIRD_SSH:
|
||||
return types.PolicyRuleProtocolNetbirdSSH
|
||||
default:
|
||||
return types.PolicyRuleProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
func lookupAgentVersion(table []string, idx uint32) string {
|
||||
if int(idx) < len(table) {
|
||||
return table[idx]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringSliceToSet(s []string) map[string]struct{} {
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(s))
|
||||
for _, v := range s {
|
||||
out[v] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// domainsFromPunycode is a thin wrapper that converts a punycode list back to
|
||||
// the domain.List type the route.Route struct expects. It accepts the
|
||||
// punycode strings as-is (no extra decoding) — symmetric with
|
||||
// route.Domains.ToPunycodeList() used in the encoder.
|
||||
func domainsFromPunycode(punycoded []string) domain.List {
|
||||
if len(punycoded) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(domain.List, 0, len(punycoded))
|
||||
for _, d := range punycoded {
|
||||
out = append(out, domain.Domain(d))
|
||||
}
|
||||
return out
|
||||
}
|
||||
323
shared/management/networkmap/encode.go
Normal file
323
shared/management/networkmap/encode.go
Normal file
@@ -0,0 +1,323 @@
|
||||
// Package networkmap contains the shared NetworkMap helpers that both the
|
||||
// management server and the client agent need.
|
||||
//
|
||||
// The proto-conversion helpers (types.NetworkMap → proto.NetworkMap) live
|
||||
// here so the client can run the same conversion locally after deriving its
|
||||
// NetworkMap from a NetworkMapEnvelope, without taking a dependency on the
|
||||
// server-side conversion package (which pulls in cloud integrations and is
|
||||
// otherwise an unwanted internal import on the client).
|
||||
//
|
||||
// The helpers are pure functions over inputs — no caches, no IO, no logging
|
||||
// beyond a context-aware error log when an individual user-id hash fails.
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"net/netip"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/types"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
// ToProtocolRoutes converts a slice of typed routes to their proto form.
|
||||
func ToProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, ToProtocolRoute(r))
|
||||
}
|
||||
return protoRoutes
|
||||
}
|
||||
|
||||
// ToProtocolRoute converts one typed route to its proto form.
|
||||
func ToProtocolRoute(route *nbroute.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
// ToProtocolFirewallRules converts the firewall rules to the protocol form.
|
||||
// When useSourcePrefixes is true, the compact SourcePrefixes field is
|
||||
// populated alongside the deprecated PeerIP for forward compatibility.
|
||||
// Wildcard rules ("0.0.0.0") are expanded into separate v4/v6 SourcePrefixes
|
||||
// when includeIPv6 is true.
|
||||
func ToProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, 0, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
fwRule := &proto.FirewallRule{
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
|
||||
Direction: GetProtoDirection(rule.Direction),
|
||||
Action: GetProtoAction(rule.Action),
|
||||
Protocol: GetProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
|
||||
if useSourcePrefixes && rule.PeerIP != "" {
|
||||
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
|
||||
}
|
||||
|
||||
if ShouldUsePortRange(fwRule) {
|
||||
fwRule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
|
||||
result = append(result, fwRule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is
|
||||
// unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
addr, err := netip.ParseAddr(rule.PeerIP)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !addr.IsUnspecified() {
|
||||
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
|
||||
return nil
|
||||
}
|
||||
|
||||
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
|
||||
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
|
||||
|
||||
if !includeIPv6 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
|
||||
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
|
||||
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
|
||||
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
|
||||
if ShouldUsePortRange(v6Rule) {
|
||||
v6Rule.PortInfo = rule.PortRange.ToProto()
|
||||
}
|
||||
return []*proto.FirewallRule{v6Rule}
|
||||
}
|
||||
|
||||
// GetProtoDirection converts the direction to proto.RuleDirection.
|
||||
func GetProtoDirection(direction int) proto.RuleDirection {
|
||||
if direction == types.FirewallRuleDirectionOUT {
|
||||
return proto.RuleDirection_OUT
|
||||
}
|
||||
return proto.RuleDirection_IN
|
||||
}
|
||||
|
||||
// GetProtoAction converts the action to proto.RuleAction.
|
||||
func GetProtoAction(action string) proto.RuleAction {
|
||||
if action == string(types.PolicyTrafficActionDrop) {
|
||||
return proto.RuleAction_DROP
|
||||
}
|
||||
return proto.RuleAction_ACCEPT
|
||||
}
|
||||
|
||||
// GetProtoProtocol converts the protocol to proto.RuleProtocol.
|
||||
func GetProtoProtocol(protocol string) proto.RuleProtocol {
|
||||
switch types.PolicyRuleProtocolType(protocol) {
|
||||
case types.PolicyRuleProtocolALL:
|
||||
return proto.RuleProtocol_ALL
|
||||
case types.PolicyRuleProtocolTCP:
|
||||
return proto.RuleProtocol_TCP
|
||||
case types.PolicyRuleProtocolUDP:
|
||||
return proto.RuleProtocol_UDP
|
||||
case types.PolicyRuleProtocolICMP:
|
||||
return proto.RuleProtocol_ICMP
|
||||
case types.PolicyRuleProtocolNetbirdSSH:
|
||||
return proto.RuleProtocol_NETBIRD_SSH
|
||||
default:
|
||||
return proto.RuleProtocol_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
// GetProtoPortInfo converts route-firewall-rule port info to proto.PortInfo.
|
||||
func GetProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
var portInfo proto.PortInfo
|
||||
if rule.Port != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
|
||||
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
|
||||
portInfo.PortSelection = &proto.PortInfo_Range_{
|
||||
Range: &proto.PortInfo_Range{
|
||||
Start: uint32(portRange.Start),
|
||||
End: uint32(portRange.End),
|
||||
},
|
||||
}
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// ShouldUsePortRange reports whether the firewall rule should use a port
|
||||
// range rather than a single port (TCP/UDP without a single port).
|
||||
func ShouldUsePortRange(rule *proto.FirewallRule) bool {
|
||||
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
|
||||
}
|
||||
|
||||
// ToProtocolRoutesFirewallRules converts a slice of typed route-firewall
|
||||
// rules to proto.
|
||||
func ToProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
|
||||
result := make([]*proto.RouteFirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
result[i] = &proto.RouteFirewallRule{
|
||||
SourceRanges: rule.SourceRanges,
|
||||
Action: GetProtoAction(rule.Action),
|
||||
Destination: rule.Destination,
|
||||
Protocol: GetProtoProtocol(rule.Protocol),
|
||||
PortInfo: GetProtoPortInfo(rule),
|
||||
IsDynamic: rule.IsDynamic,
|
||||
Domains: rule.Domains.ToPunycodeList(),
|
||||
PolicyID: []byte(rule.PolicyID),
|
||||
RouteID: string(rule.RouteID),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertToProtoCustomZone converts an nbdns.CustomZone to its proto form.
|
||||
func ConvertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
SearchDomainDisabled: zone.SearchDomainDisabled,
|
||||
NonAuthoritative: zone.NonAuthoritative,
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
// ConvertToProtoNameServerGroup converts a NameServerGroup to its proto form.
|
||||
func ConvertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
// DNSConfigCache is the cache contract for amortising NameServerGroup
|
||||
// proto-conversion across peers in the same account. Server uses a concrete
|
||||
// implementation; client passes nil (no cross-peer caching needed when
|
||||
// rebuilding a single NetworkMap from an envelope).
|
||||
type DNSConfigCache interface {
|
||||
GetNameServerGroup(key string) (*proto.NameServerGroup, bool)
|
||||
SetNameServerGroup(key string, value *proto.NameServerGroup)
|
||||
}
|
||||
|
||||
// ToProtocolDNSConfig converts nbdns.Config to proto.DNSConfig. If cache is
|
||||
// non-nil, NameServerGroup proto values are cached by NSG.ID across calls —
|
||||
// the server amortises this across peers, the client passes nil.
|
||||
func ToProtocolDNSConfig(update nbdns.Config, cache DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, ConvertToProtoCustomZone(zone))
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
if cache != nil {
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(nsGroup.ID); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
continue
|
||||
}
|
||||
}
|
||||
protoGroup := ConvertToProtoNameServerGroup(nsGroup)
|
||||
if cache != nil {
|
||||
cache.SetNameServerGroup(nsGroup.ID, protoGroup)
|
||||
}
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
// AppendRemotePeerConfig appends typed peers as proto.RemotePeerConfig
|
||||
// entries to dst and returns the result.
|
||||
func AppendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
if includeIPv6 && rPeer.IPv6.IsValid() {
|
||||
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
|
||||
}
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: allowedIPs,
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
AgentVersion: rPeer.Meta.WtVersion,
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// BuildAuthorizedUsersProto deduplicates user-IDs into a hashed list and
|
||||
// builds per-machine-user index maps. Returns (hashedUsers, machineUsers).
|
||||
// Errors from individual hash failures are logged via the provided context;
|
||||
// they leave the offending user out of the result but don't abort the build.
|
||||
func BuildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
|
||||
|
||||
for machineUser, users := range authorizedUsers {
|
||||
indexes := make([]uint32, 0, len(users))
|
||||
for userID := range users {
|
||||
idx, exists := userIDToIndex[userID]
|
||||
if !exists {
|
||||
hash, err := sshauth.HashUserID(userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to hash user id")
|
||||
continue
|
||||
}
|
||||
idx = uint32(len(hashedUsers))
|
||||
userIDToIndex[userID] = idx
|
||||
hashedUsers = append(hashedUsers, hash[:])
|
||||
}
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
|
||||
}
|
||||
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
190
shared/management/networkmap/envelope.go
Normal file
190
shared/management/networkmap/envelope.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package networkmap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// EnvelopeResult is what the client engine consumes after receiving a
|
||||
// component-format NetworkMap. Both fields are populated:
|
||||
//
|
||||
// - NetworkMap is the *proto.NetworkMap shape the engine reads today via
|
||||
// update.GetNetworkMap() — built from the envelope's components by
|
||||
// running Calculate() locally + converting back through the shared
|
||||
// proto helpers + merging the optional ProxyPatch.
|
||||
// - Components is the *types.NetworkMapComponents the engine retains so
|
||||
// future incremental delta updates have a base to apply changes
|
||||
// against. The client keeps it under its sync lock.
|
||||
type EnvelopeResult struct {
|
||||
NetworkMap *proto.NetworkMap
|
||||
Components *types.NetworkMapComponents
|
||||
}
|
||||
|
||||
// EnvelopeToNetworkMap is the full client-side pipeline: decode the
|
||||
// component envelope back to a typed NetworkMapComponents, run Calculate()
|
||||
// locally to produce the typed NetworkMap, convert it to the wire form the
|
||||
// engine consumes, and fold in any ProxyPatch the server attached.
|
||||
//
|
||||
// localPeerKey is the receiving peer's WG pub key (used to derive
|
||||
// includeIPv6 / useSourcePrefixes from the receiving peer's own record in
|
||||
// the components struct, mirroring legacy ToSyncResponse behaviour).
|
||||
//
|
||||
// dnsName is the account's DNS domain ("netbird.cloud" etc.); used when
|
||||
// rebuilding the per-peer FQDNs that proto.RemotePeerConfig carries.
|
||||
func EnvelopeToNetworkMap(ctx context.Context, env *proto.NetworkMapEnvelope, localPeerKey, dnsName string) (*EnvelopeResult, error) {
|
||||
components, err := DecodeEnvelope(env)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode envelope: %w", err)
|
||||
}
|
||||
|
||||
// Find the receiving peer in the decoded components by WG key.
|
||||
// c.Peers is keyed by canonical base64 of the raw 32-byte pub key
|
||||
// (decoder re-encodes the bytes off the wire). The caller may pass a
|
||||
// non-canonical encoding (some persisted production keys carry
|
||||
// non-zero trailing padding bits that survived a legacy import), so
|
||||
// round-trip through raw bytes once to canonicalize before lookup.
|
||||
canonicalKey := canonicalizeWgKey(localPeerKey)
|
||||
localPeer := components.Peers[canonicalKey]
|
||||
if localPeer == nil {
|
||||
return nil, fmt.Errorf("receiving peer (wg_key prefix %q) not found among %d decoded peers — components have no PeerID, Calculate would return empty", trimKey(localPeerKey), len(components.Peers))
|
||||
}
|
||||
components.PeerID = canonicalKey
|
||||
|
||||
includeIPv6 := localPeer.SupportsIPv6() && localPeer.IPv6.IsValid()
|
||||
useSourcePrefixes := localPeer.SupportsSourcePrefixes()
|
||||
|
||||
typedNM := components.Calculate(ctx)
|
||||
|
||||
full := env.GetFull()
|
||||
dnsFwdPort := int64(0)
|
||||
if full != nil {
|
||||
dnsFwdPort = full.DnsForwarderPort
|
||||
}
|
||||
|
||||
protoNM := &proto.NetworkMap{
|
||||
Serial: typedNM.Network.CurrentSerial(),
|
||||
}
|
||||
if full != nil {
|
||||
protoNM.PeerConfig = full.PeerConfig
|
||||
}
|
||||
protoNM.Routes = ToProtocolRoutes(typedNM.Routes)
|
||||
protoNM.DNSConfig = ToProtocolDNSConfig(typedNM.DNSConfig, nil, dnsFwdPort)
|
||||
|
||||
remotePeers := AppendRemotePeerConfig(nil, typedNM.Peers, dnsName, includeIPv6)
|
||||
protoNM.RemotePeers = remotePeers
|
||||
protoNM.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
|
||||
protoNM.OfflinePeers = AppendRemotePeerConfig(nil, typedNM.OfflinePeers, dnsName, includeIPv6)
|
||||
|
||||
firewallRules := ToProtocolFirewallRules(typedNM.FirewallRules, includeIPv6, useSourcePrefixes)
|
||||
protoNM.FirewallRules = firewallRules
|
||||
protoNM.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
routesFirewallRules := ToProtocolRoutesFirewallRules(typedNM.RoutesFirewallRules)
|
||||
protoNM.RoutesFirewallRules = routesFirewallRules
|
||||
protoNM.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
|
||||
|
||||
if typedNM.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := BuildAuthorizedUsersProto(ctx, typedNM.AuthorizedUsers)
|
||||
userIDClaim := ""
|
||||
if full != nil {
|
||||
userIDClaim = full.UserIdClaim
|
||||
}
|
||||
protoNM.SshAuth = &proto.SSHAuth{
|
||||
AuthorizedUsers: hashedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
UserIDClaim: userIDClaim,
|
||||
}
|
||||
}
|
||||
|
||||
if typedNM.ForwardingRules != nil {
|
||||
forwardingRules := make([]*proto.ForwardingRule, 0, len(typedNM.ForwardingRules))
|
||||
for _, rule := range typedNM.ForwardingRules {
|
||||
forwardingRules = append(forwardingRules, rule.ToProto())
|
||||
}
|
||||
protoNM.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
// Merge the proxy patch the server attached. Mirrors the legacy
|
||||
// NetworkMap.Merge step that the server runs after Calculate().
|
||||
if full != nil && full.ProxyPatch != nil {
|
||||
mergeProxyPatch(protoNM, full.ProxyPatch)
|
||||
}
|
||||
|
||||
return &EnvelopeResult{
|
||||
NetworkMap: protoNM,
|
||||
Components: components,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mergeProxyPatch folds a ProxyPatch's pre-expanded fragments into the
|
||||
// proto.NetworkMap that Calculate() produced. Mirrors types.NetworkMap.Merge
|
||||
// — same six collections, deduplicated where the legacy merge dedupes.
|
||||
func mergeProxyPatch(nm *proto.NetworkMap, patch *proto.ProxyPatch) {
|
||||
nm.RemotePeers = appendUniquePeers(nm.RemotePeers, patch.Peers)
|
||||
nm.OfflinePeers = appendUniquePeers(nm.OfflinePeers, patch.OfflinePeers)
|
||||
nm.FirewallRules = append(nm.FirewallRules, patch.FirewallRules...)
|
||||
nm.Routes = append(nm.Routes, patch.Routes...)
|
||||
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, patch.RouteFirewallRules...)
|
||||
nm.ForwardingRules = append(nm.ForwardingRules, patch.ForwardingRules...)
|
||||
if len(nm.RemotePeers) > 0 {
|
||||
nm.RemotePeersIsEmpty = false
|
||||
}
|
||||
if len(nm.FirewallRules) > 0 {
|
||||
nm.FirewallRulesIsEmpty = false
|
||||
}
|
||||
if len(nm.RoutesFirewallRules) > 0 {
|
||||
nm.RoutesFirewallRulesIsEmpty = false
|
||||
}
|
||||
}
|
||||
|
||||
// appendUniquePeers dedupes by WgPubKey — mirrors legacy
|
||||
// mergeUniquePeersByID's intent (legacy keyed off Peer.ID; in proto form the
|
||||
// closest stable identifier is WgPubKey).
|
||||
func appendUniquePeers(dst, extra []*proto.RemotePeerConfig) []*proto.RemotePeerConfig {
|
||||
if len(extra) == 0 {
|
||||
return dst
|
||||
}
|
||||
seen := make(map[string]struct{}, len(dst))
|
||||
for _, p := range dst {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
seen[p.WgPubKey] = struct{}{}
|
||||
}
|
||||
for _, p := range extra {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p.WgPubKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p.WgPubKey] = struct{}{}
|
||||
dst = append(dst, p)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func trimKey(s string) string {
|
||||
if len(s) > 12 {
|
||||
return s[:12]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// canonicalizeWgKey normalises a base64-encoded WireGuard public key so it
|
||||
// matches the canonical encoding emitted by the envelope decoder. Returns
|
||||
// the input unchanged when it does not decode to 32 raw bytes (caller will
|
||||
// hit a miss in the peer map and surface the error).
|
||||
func canonicalizeWgKey(s string) string {
|
||||
raw, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil || len(raw) != 32 {
|
||||
return s
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(raw)
|
||||
}
|
||||
|
||||
201
shared/management/networkmap/envelope_test.go
Normal file
201
shared/management/networkmap/envelope_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package networkmap_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
mgmtgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestEnvelopeToNetworkMap_RoundTrip exercises the full client-side pipeline:
|
||||
// build a small components struct, encode an envelope, marshal/unmarshal the
|
||||
// wire bytes, decode back via EnvelopeToNetworkMap, and verify the result is
|
||||
// non-empty and consistent.
|
||||
func TestEnvelopeToNetworkMap_RoundTrip(t *testing.T) {
|
||||
c, localPeerKey := buildSmokeComponents(t)
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
|
||||
wire, err := goproto.Marshal(envelope)
|
||||
require.NoError(t, err, "marshal envelope")
|
||||
|
||||
var decoded proto.NetworkMapEnvelope
|
||||
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
|
||||
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
|
||||
require.NoError(t, err, "EnvelopeToNetworkMap")
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.NetworkMap, "decoded NetworkMap must be non-nil")
|
||||
require.NotNil(t, result.Components, "Components must be retained for future delta updates")
|
||||
require.NotNil(t, result.Components.AccountSettings)
|
||||
require.NotEmpty(t, result.NetworkMap.RemotePeers, "two-peer allow policy should produce one remote peer")
|
||||
require.NotEmpty(t, result.NetworkMap.FirewallRules, "two-peer allow policy should produce firewall rules")
|
||||
}
|
||||
|
||||
// TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH guards against the
|
||||
// scenario where a rule with Protocol=NetbirdSSH leaks the enum value into
|
||||
// proto.FirewallRule.Protocol. Calculate() must rewrite NetbirdSSH → TCP
|
||||
// before forming firewall rules. Without that rewrite, agents fall into
|
||||
// UNKNOWN-protocol handling, which on some platforms downgrades to
|
||||
// allow-all — a real security regression.
|
||||
func TestCalculate_FirewallRuleProtocol_NeverNetbirdSSH(t *testing.T) {
|
||||
c, localPeerKey := buildSmokeComponents(t)
|
||||
// Replace the smoke policy with a NetbirdSSH-protocol allow.
|
||||
c.Policies = []*types.Policy{{
|
||||
ID: "pol-ssh", AccountSeqID: 2, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-all"},
|
||||
Destinations: []string{"group-all"},
|
||||
}},
|
||||
}}
|
||||
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
wire, err := goproto.Marshal(envelope)
|
||||
require.NoError(t, err)
|
||||
var decoded proto.NetworkMapEnvelope
|
||||
require.NoError(t, goproto.Unmarshal(wire, &decoded))
|
||||
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, result.NetworkMap.FirewallRules, "ssh policy should produce firewall rules")
|
||||
for i, fr := range result.NetworkMap.FirewallRules {
|
||||
require.NotEqualf(t, proto.RuleProtocol_NETBIRD_SSH, fr.Protocol,
|
||||
"FirewallRules[%d].Protocol must be the rewritten TCP, not NETBIRD_SSH", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvelopeToNetworkMap_NilEnvelope(t *testing.T) {
|
||||
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), nil, "key", "netbird.cloud")
|
||||
require.Error(t, err, "nil envelope must produce an error rather than panic")
|
||||
}
|
||||
|
||||
func TestEnvelopeToNetworkMap_FullPayloadMissing(t *testing.T) {
|
||||
env := &proto.NetworkMapEnvelope{}
|
||||
_, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), env, "key", "netbird.cloud")
|
||||
require.Error(t, err, "envelope with no Full payload must produce an error")
|
||||
}
|
||||
|
||||
// TestDecodeEnvelope_MalformedWgKeyPeerSkipped feeds an envelope where one
|
||||
// peer has a wg_pub_key that is not 32 bytes long. The decoder must skip
|
||||
// that peer (keeping the rest of the snapshot usable) instead of aborting
|
||||
// the whole sync — mirrors legacy behaviour that tolerates an occasional
|
||||
// bad row.
|
||||
func TestDecodeEnvelope_MalformedWgKeyPeerSkipped(t *testing.T) {
|
||||
c, localPeerKey := buildSmokeComponents(t)
|
||||
envelope := mgmtgrpc.EncodeNetworkMapEnvelope(mgmtgrpc.ComponentsEnvelopeInput{
|
||||
Components: c,
|
||||
DNSDomain: "netbird.cloud",
|
||||
})
|
||||
require.NotNil(t, envelope.GetFull())
|
||||
|
||||
full := envelope.GetFull()
|
||||
require.Len(t, full.Peers, 2, "smoke fixture should have two peers")
|
||||
|
||||
// Truncate the second peer's wg_pub_key so it fails the length gate.
|
||||
full.Peers[1].WgPubKey = full.Peers[1].WgPubKey[:31]
|
||||
|
||||
wire, err := goproto.Marshal(envelope)
|
||||
require.NoError(t, err, "marshal envelope")
|
||||
var decoded proto.NetworkMapEnvelope
|
||||
require.NoError(t, goproto.Unmarshal(wire, &decoded), "unmarshal envelope")
|
||||
|
||||
result, err := nbnetworkmap.EnvelopeToNetworkMap(context.Background(), &decoded, localPeerKey, "netbird.cloud")
|
||||
require.NoError(t, err, "EnvelopeToNetworkMap must tolerate one bad peer key")
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Components)
|
||||
require.Len(t, result.Components.Peers, 1, "the well-formed peer survives, the malformed one is dropped")
|
||||
}
|
||||
|
||||
// buildSmokeComponents returns a minimal NetworkMapComponents (2 peers, 1
|
||||
// group, 1 allow policy) plus the receiving peer's WG public key. Sufficient
|
||||
// to validate the encode → marshal → decode → Calculate pipeline produces
|
||||
// non-empty output.
|
||||
func buildSmokeComponents(t *testing.T) (*types.NetworkMapComponents, string) {
|
||||
t.Helper()
|
||||
|
||||
peerAKey := randomWgKey(t)
|
||||
peerBKey := randomWgKey(t)
|
||||
|
||||
peerA := &nbpeer.Peer{
|
||||
ID: "peer-A",
|
||||
Key: peerAKey,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
|
||||
DNSLabel: "peerA",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
peerB := &nbpeer.Peer{
|
||||
ID: "peer-B",
|
||||
Key: peerBKey,
|
||||
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
|
||||
DNSLabel: "peerB",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
|
||||
}
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-all", AccountSeqID: 1, Name: "All",
|
||||
Peers: []string{"peer-A", "peer-B"},
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: "pol-allow", AccountSeqID: 1, Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-allow",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-all"},
|
||||
Destinations: []string{"group-all"},
|
||||
}},
|
||||
}
|
||||
|
||||
c := &types.NetworkMapComponents{
|
||||
PeerID: "peer-A",
|
||||
Network: &types.Network{
|
||||
Identifier: "net-smoke",
|
||||
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
|
||||
Serial: 1,
|
||||
},
|
||||
AccountSettings: &types.AccountSettingsInfo{},
|
||||
DNSSettings: &types.DNSSettings{},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-A": peerA,
|
||||
"peer-B": peerB,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group-all": group,
|
||||
},
|
||||
Policies: []*types.Policy{policy},
|
||||
}
|
||||
return c, peerAKey
|
||||
}
|
||||
|
||||
func randomWgKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
var raw [32]byte
|
||||
_, err := rand.Read(raw[:])
|
||||
require.NoError(t, err)
|
||||
return base64.StdEncoding.EncodeToString(raw[:])
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -150,6 +150,12 @@ message SyncResponse {
|
||||
// SSO-registered; client clears its anchor
|
||||
// set, valid timestamp → new absolute UTC deadline
|
||||
google.protobuf.Timestamp sessionExpiresAt = 7;
|
||||
|
||||
// NetworkMapEnvelope carries the component-based wire format for peers that
|
||||
// advertise PeerCapabilityComponentNetworkMap. When set, NetworkMap (field 5)
|
||||
// is left empty: management ships components and the client runs Calculate()
|
||||
// locally instead of receiving an expanded NetworkMap.
|
||||
NetworkMapEnvelope NetworkMapEnvelope = 8;
|
||||
}
|
||||
|
||||
message SyncMetaRequest {
|
||||
@@ -229,6 +235,8 @@ enum PeerCapability {
|
||||
PeerCapabilitySourcePrefixes = 1;
|
||||
// Client handles IPv6 overlay addresses and firewall rules.
|
||||
PeerCapabilityIPv6Overlay = 2;
|
||||
// Client receives NetworkMap as components and assembles it locally.
|
||||
PeerCapabilityComponentNetworkMap = 3;
|
||||
}
|
||||
|
||||
// PeerSystemMeta is machine meta data like OS and version.
|
||||
@@ -611,6 +619,13 @@ enum RuleProtocol {
|
||||
UDP = 3;
|
||||
ICMP = 4;
|
||||
CUSTOM = 5;
|
||||
// NETBIRD_SSH (types.PolicyRuleProtocolType "netbird-ssh") is the marker
|
||||
// policy rule that drives SSH-server activation in Calculate(). The legacy
|
||||
// proto.FirewallRule path doesn't ship this value (Calculate already
|
||||
// expands SSH rules into TCP/22 before encoding), but the components path
|
||||
// ships RAW policies — the client must see this protocol to derive
|
||||
// AuthorizedUsers locally.
|
||||
NETBIRD_SSH = 6;
|
||||
}
|
||||
|
||||
enum RuleDirection {
|
||||
@@ -751,3 +766,462 @@ message StopExposeRequest {
|
||||
}
|
||||
|
||||
message StopExposeResponse {}
|
||||
|
||||
// =====================================================================
|
||||
// Component-based NetworkMap wire format (PeerCapabilityComponentNetworkMap).
|
||||
//
|
||||
// Peers that advertise this capability receive NetworkMap building blocks
|
||||
// (peers + groups + policies + routes + dns + ssh + forwarding) and run the
|
||||
// expansion (Calculate) locally instead of receiving a fully-expanded
|
||||
// NetworkMap from the server.
|
||||
// =====================================================================
|
||||
|
||||
// NetworkMapEnvelope wraps either a full snapshot or a delta. Only Full is
|
||||
// emitted today; Delta is reserved for the incremental-update work.
|
||||
message NetworkMapEnvelope {
|
||||
oneof payload {
|
||||
NetworkMapComponentsFull full = 1;
|
||||
NetworkMapComponentsDelta delta = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// NetworkMapComponentsFull is the full per-peer component snapshot. The
|
||||
// client decodes it into a types.NetworkMapComponents and runs Calculate()
|
||||
// locally to produce the same NetworkMap the legacy server path would have
|
||||
// produced. Every field carries RAW component data — no server-side
|
||||
// expansion (firewall rules, DNS config, SSH auth, route firewall rules,
|
||||
// forwarding rules) is shipped; the client computes those itself.
|
||||
message NetworkMapComponentsFull {
|
||||
uint64 serial = 1;
|
||||
|
||||
// Peer config for the receiving peer (legacy proto.PeerConfig kept as-is —
|
||||
// it carries the receiving peer's own overlay address, FQDN, SSH config).
|
||||
PeerConfig peer_config = 2;
|
||||
|
||||
// Account-level network metadata (id, IPv4/IPv6 overlay subnets, DNS,
|
||||
// serial). Mirrors types.Network.
|
||||
AccountNetwork network = 3;
|
||||
|
||||
// Account-level settings the client needs for its local Calculate().
|
||||
AccountSettingsCompact account_settings = 4;
|
||||
|
||||
// Account DNS settings (mirrors types.DNSSettings).
|
||||
DNSSettingsCompact dns_settings = 5;
|
||||
|
||||
// Domain shared across all peers in this account, e.g. "netbird.cloud".
|
||||
// Each peer's FQDN is dns_label + "." + dns_domain.
|
||||
string dns_domain = 6;
|
||||
|
||||
// Custom-zone domain for this peer's view (c.CustomZoneDomain). Empty when
|
||||
// the peer has no custom zone records.
|
||||
string custom_zone_domain = 7;
|
||||
|
||||
// Deduplicated agent versions; PeerCompact.agent_version_idx indexes here.
|
||||
// Empty string at index 0 if any peer has no version.
|
||||
repeated string agent_versions = 8;
|
||||
|
||||
// All peers (deduplicated). The client splits peers into online / offline
|
||||
// locally using account_settings.peer_login_expiration on receive.
|
||||
repeated PeerCompact peers = 9;
|
||||
|
||||
// Indexes into peers for the subset that may act as routers.
|
||||
repeated uint32 router_peer_indexes = 10;
|
||||
|
||||
// Policies that affect the receiving peer.
|
||||
repeated PolicyCompact policies = 11;
|
||||
|
||||
// Groups in unspecified order — clients key off id (account_seq_id).
|
||||
repeated GroupCompact groups = 12;
|
||||
|
||||
// Routes relevant to this peer, raw shape (mirrors []*route.Route).
|
||||
repeated RouteRaw routes = 13;
|
||||
|
||||
// Nameserver groups (mirrors []*nbdns.NameServerGroup).
|
||||
repeated NameServerGroupRaw nameserver_groups = 14;
|
||||
|
||||
// All DNS records the client needs to assemble its custom zone. Reuses
|
||||
// the existing SimpleRecord wire shape.
|
||||
repeated SimpleRecord all_dns_records = 15;
|
||||
|
||||
// Custom zones (typically the peer's own zone). Reuses the existing
|
||||
// CustomZone wire shape.
|
||||
repeated CustomZone account_zones = 16;
|
||||
|
||||
// Network resources (mirrors []*resourceTypes.NetworkResource).
|
||||
repeated NetworkResourceRaw network_resources = 17;
|
||||
|
||||
// Routers per network. Outer key: network account_seq_id. Each entry is
|
||||
// the set of routers backing that network for this peer's view.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: the map key changed from string (network xid)
|
||||
// to uint32 (account_seq_id). Field 18 was reused without a `reserved`
|
||||
// entry because capability=3 has never been released — every cap=3
|
||||
// producer and consumer carries the same regenerated descriptor. Do NOT
|
||||
// reuse this pattern for any further wire change once cap=3 ships.
|
||||
map<uint32, NetworkRouterList> routers_map = 18;
|
||||
|
||||
// For each NetworkResource account_seq_id, the indexes into policies[]
|
||||
// that apply to it.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
|
||||
map<uint32, PolicyIndexes> resource_policies_map = 19;
|
||||
|
||||
// Group-id (account_seq_id) → user ids authorized for SSH on members.
|
||||
map<uint32, UserIDList> group_id_to_user_ids = 20;
|
||||
|
||||
// Account-level allowed user ids (used by Calculate() when assembling SSH
|
||||
// authorized users for the receiving peer).
|
||||
repeated string allowed_user_ids = 21;
|
||||
|
||||
// Per posture-check account_seq_id, the set of peer indexes that failed
|
||||
// the check. Server-side evaluation result; clients do not re-evaluate.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: see routers_map note above.
|
||||
map<uint32, PeerIndexSet> posture_failed_peers = 22;
|
||||
|
||||
// Account-level DNS forwarder port (mirrors the legacy
|
||||
// proto.DNSConfig.ForwarderPort). Computed by the controller from peer
|
||||
// versions; clients fold it into their Calculate() DNS output.
|
||||
int64 dns_forwarder_port = 23;
|
||||
|
||||
// Pre-expanded NetworkMap fragments injected post-Calculate by external
|
||||
// controllers (BYOP / port-forwarding proxies). The receiving client
|
||||
// merges these into its locally-computed NetworkMap the same way the
|
||||
// legacy server does via NetworkMap.Merge — so downstream consumers see
|
||||
// a unified merged result regardless of source.
|
||||
ProxyPatch proxy_patch = 24;
|
||||
|
||||
// SSH UserIDClaim — server-side HttpServerConfig.AuthUserIDClaim, or
|
||||
// "sub" by default. Populated in proto.SSHAuth.UserIDClaim when the
|
||||
// client rebuilds the NetworkMap from this envelope. Empty when the
|
||||
// account has no AuthorizedUsers (and thus no SshAuth to populate).
|
||||
string user_id_claim = 25;
|
||||
|
||||
// Reserved for future component additions (incremental_serial, parent_seq,
|
||||
// etc.) without forcing a renumber.
|
||||
reserved 26 to 50;
|
||||
}
|
||||
|
||||
// ProxyPatch carries NetworkMap fragments that don't fit the component-graph
|
||||
// model — they're pre-expanded by external controllers (BYOP /
|
||||
// port-forwarding proxies) and injected post-Calculate. Fields use the
|
||||
// legacy wire types because the proxy delivers them pre-formed; there is
|
||||
// no raw component shape to convert from. Empty when no proxy is active.
|
||||
message ProxyPatch {
|
||||
repeated RemotePeerConfig peers = 1;
|
||||
repeated RemotePeerConfig offline_peers = 2;
|
||||
repeated FirewallRule firewall_rules = 3;
|
||||
repeated Route routes = 4;
|
||||
repeated RouteFirewallRule route_firewall_rules = 5;
|
||||
repeated ForwardingRule forwarding_rules = 6;
|
||||
}
|
||||
|
||||
// AccountSettingsCompact carries the account-level settings the client needs
|
||||
// to evaluate locally. Mirrors the subset of types.AccountSettingsInfo that
|
||||
// Calculate() actually reads — login-expiration (used to filter expired
|
||||
// peers). Inactivity expiration is purely server-side bookkeeping and is not
|
||||
// shipped.
|
||||
message AccountSettingsCompact {
|
||||
bool peer_login_expiration_enabled = 1;
|
||||
// Login expiration window. Unit is nanoseconds (matches time.Duration).
|
||||
int64 peer_login_expiration_ns = 2;
|
||||
}
|
||||
|
||||
// AccountNetwork is the account-level overlay metadata. Mirrors types.Network
|
||||
// so the client can populate NetworkMap.Network without a server round-trip.
|
||||
message AccountNetwork {
|
||||
string identifier = 1;
|
||||
// IPv4 overlay subnet in CIDR form (e.g. "100.64.0.0/16").
|
||||
string net_cidr = 2;
|
||||
// IPv6 ULA overlay subnet in CIDR form (e.g. "fd00:4e42::/64"). Empty when
|
||||
// the account has no IPv6 overlay yet.
|
||||
string net_v6_cidr = 3;
|
||||
string dns = 4;
|
||||
uint64 serial = 5;
|
||||
}
|
||||
|
||||
// NetworkMapComponentsDelta is reserved for the incremental update
|
||||
// protocol. Field numbers 1–100 are pre-allocated to keep room for the
|
||||
// planned event types without needing a renumber.
|
||||
message NetworkMapComponentsDelta {
|
||||
reserved 1 to 100;
|
||||
}
|
||||
|
||||
// PeerCompact is the wire-shape of a remote peer used by the component
|
||||
// format. It carries every field of types.Peer that the client's local
|
||||
// Calculate() reads — including the trio needed to evaluate
|
||||
// LoginExpired() (added_with_sso_login + login_expiration_enabled +
|
||||
// last_login_unix_nano). Fields the client does not consume (Status,
|
||||
// CreatedAt, etc.) are not shipped.
|
||||
message PeerCompact {
|
||||
// Raw 32-byte WireGuard public key (no base64 wrapping).
|
||||
bytes wg_pub_key = 1;
|
||||
|
||||
// Raw 4-byte IPv4 overlay address. Always a /32 host route, so no prefix
|
||||
// byte is needed.
|
||||
bytes ip = 2;
|
||||
|
||||
// Raw 16-byte IPv6 overlay address; always a /128 host route. Empty when
|
||||
// the peer has no IPv6 overlay address.
|
||||
bytes ipv6 = 3;
|
||||
|
||||
// Raw SSH public key bytes (or empty).
|
||||
bytes ssh_pub_key = 4;
|
||||
|
||||
// DNS label without the account's domain suffix. Full FQDN is
|
||||
// dns_label + "." + NetworkMapComponentsFull.dns_domain.
|
||||
string dns_label = 5;
|
||||
|
||||
// Index into NetworkMapComponentsFull.agent_versions.
|
||||
uint32 agent_version_idx = 6;
|
||||
|
||||
// True iff the peer was added via SSO login (i.e., types.Peer.UserID is
|
||||
// non-empty). Combined with login_expiration_enabled and
|
||||
// last_login_unix_nano this lets the client reproduce
|
||||
// (*Peer).LoginExpired() locally.
|
||||
bool added_with_sso_login = 7;
|
||||
|
||||
// True when the peer's login can expire — mirrors
|
||||
// types.Peer.LoginExpirationEnabled.
|
||||
bool login_expiration_enabled = 8;
|
||||
|
||||
// Unix-nanosecond timestamp of the peer's last login. 0 when the peer has
|
||||
// never logged in (server stores nil; client treats 0 as "epoch", which
|
||||
// makes a fresh peer immediately expired iff login_expiration_enabled is
|
||||
// true — the same semantics as types.Peer.GetLastLogin).
|
||||
int64 last_login_unix_nano = 9;
|
||||
|
||||
// True when the peer has an SSH server enabled locally. Used by the
|
||||
// legacy SSH path in Calculate() (`policyRuleImpliesLegacySSH`): a rule
|
||||
// with protocol ALL/TCP-with-SSH-ports activates SSH for the receiving
|
||||
// peer when this bit is set, even without an explicit NetbirdSSH rule.
|
||||
bool ssh_enabled = 10;
|
||||
|
||||
reserved 11; // was: id (string xid)
|
||||
|
||||
// Mirror of types.Peer.SupportsIPv6() — !Meta.Flags.DisableIPv6 &&
|
||||
// HasCapability(PeerCapabilityIPv6Overlay). Used by the local peer's
|
||||
// Calculate() when deciding whether to emit IPv6 firewall rules
|
||||
// (appendIPv6FirewallRule) against this peer's IPv6 address.
|
||||
bool supports_ipv6 = 12;
|
||||
|
||||
// Mirror of types.Peer.SupportsSourcePrefixes() —
|
||||
// HasCapability(PeerCapabilitySourcePrefixes). Determines whether the
|
||||
// local peer's Calculate() emits SourcePrefixes alongside legacy PeerIP
|
||||
// fields in proto.FirewallRule.
|
||||
bool supports_source_prefixes = 13;
|
||||
|
||||
// Mirror of types.Peer.Meta.Flags.ServerSSHAllowed. Read by Calculate()
|
||||
// when expanding TCP port-22 firewall rules — the native SSH companion
|
||||
// (port 22022) is only added when this flag is set and the peer agent
|
||||
// version supports it.
|
||||
bool server_ssh_allowed = 14;
|
||||
}
|
||||
|
||||
// PolicyCompact is the compact form of a policy rule. Group references use
|
||||
// the per-account integer ids from account_seq_counters; the client resolves
|
||||
// them against NetworkMapComponentsFull.groups. Direction is derived per-peer
|
||||
// on the client (ingress when the peer is in destination_group_ids, egress
|
||||
// when in source_group_ids; both when bidirectional).
|
||||
message PolicyCompact {
|
||||
// Per-account integer id (matches policies.account_seq_id). Used as a
|
||||
// stable reference for ResourcePoliciesMap.indexes and future delta
|
||||
// updates.
|
||||
uint32 id = 1;
|
||||
|
||||
RuleAction action = 2;
|
||||
RuleProtocol protocol = 3;
|
||||
bool bidirectional = 4;
|
||||
|
||||
// Single ports referenced by the rule.
|
||||
repeated uint32 ports = 5;
|
||||
|
||||
// Port ranges (start..end) referenced by the rule.
|
||||
repeated PortInfo.Range port_ranges = 6;
|
||||
|
||||
// Group ids (account_seq_id) of source / destination groups.
|
||||
repeated uint32 source_group_ids = 7;
|
||||
repeated uint32 destination_group_ids = 8;
|
||||
|
||||
reserved 9; // was: xid (string)
|
||||
|
||||
// SSH authorization fields. PolicyRule.AuthorizedGroups maps the rule's
|
||||
// applicable group ids (account_seq_id) to a list of local-user names —
|
||||
// when a peer in one of those groups is the SSH destination, the named
|
||||
// local users gain access. AuthorizedUser is the single-user form
|
||||
// (legacy: rule scopes SSH to one specific user id).
|
||||
//
|
||||
// Both fields are only consumed by Calculate() when the rule's protocol
|
||||
// is NetbirdSSH (or the legacy implicit-SSH heuristic).
|
||||
map<uint32, UserNameList> authorized_groups = 10;
|
||||
string authorized_user = 11;
|
||||
|
||||
// Resource-typed rule sources/destinations. When a rule targets a specific
|
||||
// peer (rather than groups), Calculate() reads SourceResource /
|
||||
// DestinationResource — without these the rule's connection resources
|
||||
// can't be produced on the client. ResourceCompact's peer_index refers to
|
||||
// NetworkMapComponentsFull.peers; type is the raw ResourceType string
|
||||
// ("peer", "host", "subnet", "domain"). Only "peer" is meaningful for
|
||||
// Calculate's resource-typed rule path today.
|
||||
ResourceCompact source_resource = 12;
|
||||
ResourceCompact destination_resource = 13;
|
||||
|
||||
// Posture-check seq ids gating this policy's source peers. Calculate()
|
||||
// reads them when filtering rule peers (peers that fail any listed check
|
||||
// are dropped from sourcePeers). Match keys in
|
||||
// NetworkMapComponentsFull.posture_failed_peers.
|
||||
repeated uint32 source_posture_check_seq_ids = 15;
|
||||
|
||||
reserved 14; // was: source_posture_check_ids (repeated string xid)
|
||||
}
|
||||
|
||||
// ResourceCompact mirrors types.Resource. Used by PolicyCompact to carry
|
||||
// rule.SourceResource / rule.DestinationResource when the rule targets a
|
||||
// specific resource (typically a peer) rather than groups.
|
||||
// peer_index_set tells whether peer_index is valid (proto3 uint32 cannot
|
||||
// disambiguate "0" from "unset"); set only when type == "peer".
|
||||
message ResourceCompact {
|
||||
string type = 1;
|
||||
bool peer_index_set = 2;
|
||||
uint32 peer_index = 3;
|
||||
reserved 4; // future: host/subnet/domain references when needed
|
||||
}
|
||||
|
||||
// UserNameList is a list of local-user names — used as the value type in
|
||||
// PolicyCompact.authorized_groups.
|
||||
message UserNameList {
|
||||
repeated string names = 1;
|
||||
}
|
||||
|
||||
// GroupCompact is the wire-shape of a group: per-account integer id, optional
|
||||
// name, and indexes into NetworkMapComponentsFull.peers identifying members.
|
||||
message GroupCompact {
|
||||
// Per-account integer id (matches groups.account_seq_id). Used by
|
||||
// PolicyCompact.source_group_ids / destination_group_ids.
|
||||
uint32 id = 1;
|
||||
|
||||
// Group name; only sent when non-empty (clients use it for diagnostics).
|
||||
string name = 2;
|
||||
|
||||
// Indexes into NetworkMapComponentsFull.peers.
|
||||
repeated uint32 peer_indexes = 3;
|
||||
}
|
||||
|
||||
// DNSSettingsCompact mirrors types.DNSSettings.
|
||||
message DNSSettingsCompact {
|
||||
// Group ids (account_seq_id) whose DNS management is disabled.
|
||||
repeated uint32 disabled_management_group_ids = 1;
|
||||
}
|
||||
|
||||
// RouteRaw mirrors *route.Route (the domain type), trimmed to fields that
|
||||
// types.NetworkMapComponents.Calculate() reads. Group references are
|
||||
// account_seq_ids; the routing peer (when set) is referenced by index into
|
||||
// NetworkMapComponentsFull.peers.
|
||||
message RouteRaw {
|
||||
// Per-account integer id (matches routes.account_seq_id).
|
||||
uint32 id = 1;
|
||||
string net_id = 2;
|
||||
string description = 3;
|
||||
|
||||
// Either network_cidr (e.g. "10.0.0.0/16") or domains is set, not both.
|
||||
string network_cidr = 4;
|
||||
repeated string domains = 5;
|
||||
bool keep_route = 6;
|
||||
|
||||
// Routing peer reference: peer_index_set tells whether peer_index is valid
|
||||
// (proto3 uint32 cannot disambiguate "0" from "unset"). Mutually exclusive
|
||||
// with peer_group_ids.
|
||||
//
|
||||
// peer_index decodes back to types.Peer.ID (the peer's xid string), NOT
|
||||
// to its WireGuard public key. This matches the server-side data flow:
|
||||
// c.Routes carry route.Peer = peer.ID, and getRoutingPeerRoutes mutates
|
||||
// it to peer.Key only after the route has been admitted to the network
|
||||
// map. Decoders MUST set Route.Peer = peer.ID; the legacy Calculate()
|
||||
// path will substitute the WG key downstream.
|
||||
bool peer_index_set = 7;
|
||||
uint32 peer_index = 8;
|
||||
repeated uint32 peer_group_ids = 9;
|
||||
|
||||
int32 network_type = 10;
|
||||
bool masquerade = 11;
|
||||
int32 metric = 12;
|
||||
bool enabled = 13;
|
||||
repeated uint32 group_ids = 14;
|
||||
repeated uint32 access_control_group_ids = 15;
|
||||
bool skip_auto_apply = 16;
|
||||
|
||||
reserved 17; // was: xid (string)
|
||||
}
|
||||
|
||||
// NameServerGroupRaw mirrors *nbdns.NameServerGroup. Distinct from the
|
||||
// legacy NameServerGroup (which is the wire-trimmed shape consumed by
|
||||
// proto.DNSConfig and lacks the Name/Description/Groups/Enabled fields).
|
||||
message NameServerGroupRaw {
|
||||
uint32 id = 1; // nameserver_groups.account_seq_id
|
||||
string name = 2;
|
||||
string description = 3;
|
||||
// Reuses the legacy NameServer wire shape (IP as string).
|
||||
repeated NameServer nameservers = 4;
|
||||
// Group ids (account_seq_id) the NSG distributes nameservers to.
|
||||
repeated uint32 group_ids = 5;
|
||||
bool primary = 6;
|
||||
repeated string domains = 7;
|
||||
bool enabled = 8;
|
||||
bool search_domains_enabled = 9;
|
||||
}
|
||||
|
||||
// NetworkResourceRaw mirrors *resourceTypes.NetworkResource.
|
||||
//
|
||||
// INCOMPATIBLE WIRE CHANGE: field 2 changed from `string network_id` (xid)
|
||||
// to `uint32 network_seq` without a `reserved` entry. Safe only because
|
||||
// capability=3 has never been released — every cap=3 producer and consumer
|
||||
// carries the same regenerated descriptor. Do NOT reuse this pattern once
|
||||
// cap=3 ships.
|
||||
message NetworkResourceRaw {
|
||||
uint32 id = 1; // network_resources.account_seq_id
|
||||
uint32 network_seq = 2; // networks.account_seq_id (replaces xid)
|
||||
string name = 3;
|
||||
string description = 4;
|
||||
// Resource type: "host" / "subnet" / "domain".
|
||||
string type = 5;
|
||||
string address = 6;
|
||||
string domain_value = 7; // resource.Domain
|
||||
string prefix_cidr = 8;
|
||||
bool enabled = 9;
|
||||
reserved 10; // was: xid (string)
|
||||
}
|
||||
|
||||
// NetworkRouterList carries the routers backing one network.
|
||||
message NetworkRouterList {
|
||||
// Routers in this network, keyed by peer_index (the routing peer).
|
||||
repeated NetworkRouterEntry entries = 1;
|
||||
}
|
||||
|
||||
// NetworkRouterEntry mirrors a single *routerTypes.NetworkRouter; the routing
|
||||
// peer is referenced by index into NetworkMapComponentsFull.peers.
|
||||
message NetworkRouterEntry {
|
||||
uint32 id = 1; // network_routers.account_seq_id
|
||||
uint32 peer_index = 2;
|
||||
bool peer_index_set = 3;
|
||||
repeated uint32 peer_group_ids = 4;
|
||||
bool masquerade = 5;
|
||||
int32 metric = 6;
|
||||
bool enabled = 7;
|
||||
}
|
||||
|
||||
// PolicyIndexes is a list of indexes into NetworkMapComponentsFull.policies.
|
||||
message PolicyIndexes {
|
||||
repeated uint32 indexes = 1;
|
||||
}
|
||||
|
||||
// UserIDList is a list of user ids — used as the value type in
|
||||
// NetworkMapComponentsFull.group_id_to_user_ids.
|
||||
message UserIDList {
|
||||
repeated string user_ids = 1;
|
||||
}
|
||||
|
||||
// PeerIndexSet is a set of peer indexes — used as the value type in
|
||||
// NetworkMapComponentsFull.posture_failed_peers.
|
||||
message PeerIndexSet {
|
||||
repeated uint32 peer_indexes = 1;
|
||||
}
|
||||
|
||||
131
shared/management/types/firewall_helpers.go
Normal file
131
shared/management/types/firewall_helpers.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
firewallRuleMinPortRangesVer = "0.48.0"
|
||||
firewallRuleMinNativeSSHVer = "0.60.0"
|
||||
|
||||
nativeSSHPortString = "22022"
|
||||
nativeSSHPortNumber = 22022
|
||||
defaultSSHPortString = "22"
|
||||
defaultSSHPortNumber = 22
|
||||
)
|
||||
|
||||
type supportedFeatures struct {
|
||||
nativeSSH bool
|
||||
portRanges bool
|
||||
}
|
||||
|
||||
type LookupMap map[string]struct{}
|
||||
|
||||
func PolicyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsIncludesSSH(ports []string) bool {
|
||||
for _, port := range ports {
|
||||
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExpandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules.
|
||||
func ExpandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
|
||||
var expanded []*FirewallRule
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
for _, portRange := range rule.PortRanges {
|
||||
if len(rule.Ports) > 0 {
|
||||
break
|
||||
}
|
||||
fr := base
|
||||
|
||||
if features.portRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
expanded = addNativeSSHRule(base, expanded)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
func addNativeSSHRule(base FirewallRule, expanded []*FirewallRule) []*FirewallRule {
|
||||
shouldAdd := false
|
||||
for _, fr := range expanded {
|
||||
if isPortInRule(nativeSSHPortString, 22022, fr) {
|
||||
return expanded
|
||||
}
|
||||
if isPortInRule(defaultSSHPortString, 22, fr) {
|
||||
shouldAdd = true
|
||||
}
|
||||
}
|
||||
if !shouldAdd {
|
||||
return expanded
|
||||
}
|
||||
|
||||
fr := base
|
||||
fr.Port = nativeSSHPortString
|
||||
return append(expanded, &fr)
|
||||
}
|
||||
|
||||
func isPortInRule(portString string, portInt uint16, rule *FirewallRule) bool {
|
||||
return rule.Port == portString || (rule.PortRange.Start <= portInt && portInt <= rule.PortRange.End)
|
||||
}
|
||||
|
||||
func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *nbpeer.Peer) bool {
|
||||
return supportsNative && peer.SSHEnabled && peer.Meta.Flags.ServerSSHAllowed && rule.Protocol == PolicyRuleProtocolTCP
|
||||
}
|
||||
|
||||
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return supportedFeatures{true, true}
|
||||
}
|
||||
|
||||
var features supportedFeatures
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinNativeSSHVer, peerVer)
|
||||
features.nativeSSH = err == nil && meetMinVer
|
||||
|
||||
if features.nativeSSH {
|
||||
features.portRanges = true
|
||||
} else {
|
||||
meetMinVer, err = posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
features.portRanges = err == nil && meetMinVer
|
||||
}
|
||||
|
||||
return features
|
||||
}
|
||||
@@ -47,11 +47,11 @@ func (r *FirewallRule) Equal(other *FirewallRule) bool {
|
||||
return reflect.DeepEqual(r, other)
|
||||
}
|
||||
|
||||
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||
// GenerateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||
// For static routes, source ranges match the destination family (v4 or v6).
|
||||
// For dynamic routes (domain-based), separate v4 and v6 rules are generated
|
||||
// so the routing peer's forwarding chain allows both address families.
|
||||
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule {
|
||||
func GenerateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, includeIPv6 bool) []*RouteFirewallRule {
|
||||
rulesExists := make(map[string]struct{})
|
||||
rules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestGenerateRouteFirewallRules_V4Route(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges, "v4 route should only have v4 sources")
|
||||
@@ -86,7 +86,7 @@ func TestGenerateRouteFirewallRules_V6Route(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, []string{"fd00::1/128"}, rules[0].SourceRanges, "v6 route should only have v6 sources")
|
||||
@@ -115,7 +115,7 @@ func TestGenerateRouteFirewallRules_DynamicRoute_DualStack(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
|
||||
require.Len(t, rules, 2, "dynamic route should produce both v4 and v6 rules")
|
||||
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
|
||||
@@ -143,7 +143,7 @@ func TestGenerateRouteFirewallRules_DynamicRoute_NoV6Peers(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, true)
|
||||
|
||||
require.Len(t, rules, 1, "no v6 peers means only v4 rule")
|
||||
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
|
||||
@@ -173,7 +173,7 @@ func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
|
||||
assert.Empty(t, rules, "v6 route should produce no rules when includeIPv6 is false")
|
||||
})
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestGenerateRouteFirewallRules_IncludeIPv6False(t *testing.T) {
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
}
|
||||
|
||||
rules := generateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
|
||||
rules := GenerateRouteFirewallRules(context.Background(), r, rule, peers, FirewallRuleDirectionIN, false)
|
||||
require.Len(t, rules, 1, "dynamic route with includeIPv6=false should produce only v4 rule")
|
||||
assert.Equal(t, []string{"100.64.0.1/32", "100.64.0.2/32"}, rules[0].SourceRanges)
|
||||
})
|
||||
@@ -19,6 +19,10 @@ type Group struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
@@ -41,6 +45,14 @@ type GroupPeer struct {
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the group has been persisted long enough to have a
|
||||
// per-account sequence id allocated. Wire encoders that key off AccountSeqID
|
||||
// must skip groups that return false here — otherwise multiple unpersisted
|
||||
// groups would collide on id 0.
|
||||
func (g *Group) HasSeqID() bool {
|
||||
return g != nil && g.AccountSeqID != 0
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
@@ -74,6 +86,7 @@ func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
AccountSeqID: g.AccountSeqID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
@@ -42,6 +42,17 @@ type NetworkMapComponents struct {
|
||||
PostureFailedPeers map[string]map[string]struct{}
|
||||
|
||||
RouterPeers map[string]*nbpeer.Peer
|
||||
|
||||
// NetworkXIDToSeq maps Network.ID (xid) → AccountSeqID. Populated by the
|
||||
// account-side component builder; consumed by the envelope encoder to
|
||||
// translate RoutersMap keys and NetworkResource.NetworkID references
|
||||
// to compact uint32 ids. Legacy Calculate() doesn't consult it.
|
||||
NetworkXIDToSeq map[string]uint32
|
||||
|
||||
// PostureCheckXIDToSeq maps posture.Checks.ID (xid) → AccountSeqID.
|
||||
// Same role as NetworkXIDToSeq, used for PostureFailedPeers keys and
|
||||
// policy SourcePostureChecks references.
|
||||
PostureCheckXIDToSeq map[string]uint32
|
||||
}
|
||||
|
||||
type AccountSettingsInfo struct {
|
||||
@@ -252,7 +263,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
|
||||
} else if peerInDestinations && PolicyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||
}
|
||||
@@ -319,15 +330,15 @@ func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
} else {
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
rules = append(rules, ExpandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
|
||||
rules = appendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, firewallRuleContext{
|
||||
direction: direction,
|
||||
dirStr: dirStr,
|
||||
protocolStr: protocolStr,
|
||||
actionStr: actionStr,
|
||||
portsJoined: portsJoined,
|
||||
rules = AppendIPv6FirewallRule(rules, rulesExists, peer, targetPeer, rule, FirewallRuleContext{
|
||||
Direction: direction,
|
||||
DirStr: dirStr,
|
||||
ProtocolStr: protocolStr,
|
||||
ActionStr: actionStr,
|
||||
PortsJoined: portsJoined,
|
||||
})
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
@@ -684,7 +695,7 @@ func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID
|
||||
}
|
||||
|
||||
rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
|
||||
rules := GenerateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN, includeIPv6)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
@@ -953,21 +964,21 @@ func (c *NetworkMapComponents) addNetworksRoutingPeers(
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
type firewallRuleContext struct {
|
||||
direction int
|
||||
dirStr string
|
||||
protocolStr string
|
||||
actionStr string
|
||||
portsJoined string
|
||||
type FirewallRuleContext struct {
|
||||
Direction int
|
||||
DirStr string
|
||||
ProtocolStr string
|
||||
ActionStr string
|
||||
PortsJoined string
|
||||
}
|
||||
|
||||
func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc firewallRuleContext) []*FirewallRule {
|
||||
func AppendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct{}, peer, targetPeer *nbpeer.Peer, rule *PolicyRule, rc FirewallRuleContext) []*FirewallRule {
|
||||
if !peer.IPv6.IsValid() || !targetPeer.SupportsIPv6() || !targetPeer.IPv6.IsValid() {
|
||||
return rules
|
||||
}
|
||||
|
||||
v6IP := peer.IPv6.String()
|
||||
v6RuleID := rule.ID + v6IP + rc.dirStr + rc.protocolStr + rc.actionStr + rc.portsJoined
|
||||
v6RuleID := rule.ID + v6IP + rc.DirStr + rc.ProtocolStr + rc.ActionStr + rc.PortsJoined
|
||||
if _, ok := rulesExists[v6RuleID]; ok {
|
||||
return rules
|
||||
}
|
||||
@@ -976,12 +987,12 @@ func appendIPv6FirewallRule(rules []*FirewallRule, rulesExists map[string]struct
|
||||
v6fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: v6IP,
|
||||
Direction: rc.direction,
|
||||
Action: rc.actionStr,
|
||||
Protocol: rc.protocolStr,
|
||||
Direction: rc.Direction,
|
||||
Action: rc.ActionStr,
|
||||
Protocol: rc.ProtocolStr,
|
||||
}
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
return append(rules, &v6fr)
|
||||
}
|
||||
return append(rules, expandPortsAndRanges(v6fr, rule, targetPeer)...)
|
||||
return append(rules, ExpandPortsAndRanges(v6fr, rule, targetPeer)...)
|
||||
}
|
||||
@@ -59,6 +59,10 @@ type Policy struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
@@ -75,11 +79,19 @@ type Policy struct {
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// HasSeqID reports whether the policy has been persisted long enough to have
|
||||
// a per-account sequence id allocated. Wire encoders that key off
|
||||
// AccountSeqID must skip policies that return false here.
|
||||
func (p *Policy) HasSeqID() bool {
|
||||
return p != nil && p.AccountSeqID != 0
|
||||
}
|
||||
|
||||
// Copy returns a copy of the policy.
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
AccountSeqID: p.AccountSeqID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
Reference in New Issue
Block a user