diff --git a/client/cmd/debug.go b/client/cmd/debug.go
index 0e2717756..e3d3afe5f 100644
--- a/client/cmd/debug.go
+++ b/client/cmd/debug.go
@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
+ needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
+ needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
}
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
+ needsRestoreUp = false
cmd.Println("netbird up")
}
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
+ if needsRestoreUp {
+ if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
+ cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
+ } else {
+ cmd.Println("netbird up (restored)")
+ }
+ }
+
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 7ee745ae6..97c350d4e 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -114,6 +114,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
+ dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -123,6 +124,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
+ HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go
index 330cebe49..00f8b1a8d 100644
--- a/client/internal/debug/debug.go
+++ b/client/internal/debug/debug.go
@@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
+ "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
+service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
+ if err := g.addServiceParams(); err != nil {
+ log.Errorf("failed to add service params to debug bundle: %v", err)
+ }
+
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
return nil
}
+const (
+ serviceParamsFile = "service.json"
+ serviceParamsBundle = "service_params.json"
+ maskedValue = "***"
+ envVarPrefix = "NB_"
+ jsonKeyManagementURL = "management_url"
+ jsonKeyServiceEnv = "service_env_vars"
+)
+
+var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
+
+// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
+// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
+func (g *BundleGenerator) addServiceParams() error {
+ path := filepath.Join(configs.StateDir, serviceParamsFile)
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return fmt.Errorf("read service params: %w", err)
+ }
+
+ var params map[string]any
+	if err := json.Unmarshal(data, &params); err != nil {
+ return fmt.Errorf("parse service params: %w", err)
+ }
+
+ if g.anonymize {
+ if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
+ params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
+ }
+ }
+
+ g.sanitizeServiceEnvVars(params)
+
+ sanitizedData, err := json.MarshalIndent(params, "", " ")
+ if err != nil {
+ return fmt.Errorf("marshal sanitized service params: %w", err)
+ }
+
+ if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
+ return fmt.Errorf("add service params to zip: %w", err)
+ }
+
+ return nil
+}
+
+// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
+// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
+// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
+func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
+ envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
+ if !ok {
+ return
+ }
+
+ sanitized := make(map[string]any, len(envVars))
+ for k, v := range envVars {
+ val, _ := v.(string)
+ switch {
+ case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
+ sanitized[k] = maskedValue
+ case g.anonymize:
+ sanitized[k] = g.anonymizer.AnonymizeString(val)
+ default:
+ sanitized[k] = val
+ }
+ }
+ params[jsonKeyServiceEnv] = sanitized
+}
+
+// isSensitiveEnvVar returns true for env var names that may contain secrets.
+func isSensitiveEnvVar(key string) bool {
+ lower := strings.ToLower(key)
+ for _, s := range sensitiveEnvSubstrings {
+ if strings.Contains(lower, s) {
+ return true
+ }
+ }
+ return false
+}
+
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go
index e242b8b1b..49c18c679 100644
--- a/client/internal/debug/debug_test.go
+++ b/client/internal/debug/debug_test.go
@@ -1,8 +1,12 @@
package debug
import (
+ "archive/zip"
+ "bytes"
"encoding/json"
"net"
+ "os"
+ "path/filepath"
"strings"
"testing"
@@ -10,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
+ "github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
}
}
+func TestIsSensitiveEnvVar(t *testing.T) {
+ tests := []struct {
+ key string
+ sensitive bool
+ }{
+ {"NB_SETUP_KEY", true},
+ {"NB_API_TOKEN", true},
+ {"NB_CLIENT_SECRET", true},
+ {"NB_PASSWORD", true},
+ {"NB_CREDENTIAL", true},
+ {"NB_LOG_LEVEL", false},
+ {"NB_MANAGEMENT_URL", false},
+ {"NB_HOSTNAME", false},
+ {"HOME", false},
+ {"PATH", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.key, func(t *testing.T) {
+ assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
+ })
+ }
+}
+
+func TestSanitizeServiceEnvVars(t *testing.T) {
+ tests := []struct {
+ name string
+ anonymize bool
+ input map[string]any
+ check func(t *testing.T, params map[string]any)
+ }{
+ {
+ name: "no env vars key",
+ anonymize: false,
+ input: map[string]any{"management_url": "https://mgmt.example.com"},
+ check: func(t *testing.T, params map[string]any) {
+ t.Helper()
+ assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
+ _, ok := params[jsonKeyServiceEnv]
+ assert.False(t, ok, "service_env_vars should not be added")
+ },
+ },
+ {
+ name: "non-NB vars are masked",
+ anonymize: false,
+ input: map[string]any{
+ jsonKeyServiceEnv: map[string]any{
+ "HOME": "/root",
+ "PATH": "/usr/bin",
+ "NB_LOG_LEVEL": "debug",
+ },
+ },
+ check: func(t *testing.T, params map[string]any) {
+ t.Helper()
+ env := params[jsonKeyServiceEnv].(map[string]any)
+ assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
+ assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
+ assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
+ },
+ },
+ {
+ name: "sensitive NB vars are masked",
+ anonymize: false,
+ input: map[string]any{
+ jsonKeyServiceEnv: map[string]any{
+ "NB_SETUP_KEY": "abc123",
+ "NB_API_TOKEN": "tok_xyz",
+ "NB_LOG_LEVEL": "info",
+ },
+ },
+ check: func(t *testing.T, params map[string]any) {
+ t.Helper()
+ env := params[jsonKeyServiceEnv].(map[string]any)
+ assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
+ assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
+ assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
+ },
+ },
+ {
+ name: "safe NB vars anonymized when anonymize is true",
+ anonymize: true,
+ input: map[string]any{
+ jsonKeyServiceEnv: map[string]any{
+ "NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
+ "NB_LOG_LEVEL": "debug",
+ "NB_SETUP_KEY": "secret",
+ "SOME_OTHER": "val",
+ },
+ },
+ check: func(t *testing.T, params map[string]any) {
+ t.Helper()
+ env := params[jsonKeyServiceEnv].(map[string]any)
+ // Safe NB_ values should be anonymized (not the original, not masked)
+ mgmtVal := env["NB_MANAGEMENT_URL"].(string)
+ assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
+ assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
+
+ logVal := env["NB_LOG_LEVEL"].(string)
+ assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
+
+ // Sensitive and non-NB_ still masked
+ assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
+ assert.Equal(t, maskedValue, env["SOME_OTHER"])
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
+ g := &BundleGenerator{
+ anonymize: tt.anonymize,
+ anonymizer: anonymizer,
+ }
+ g.sanitizeServiceEnvVars(tt.input)
+ tt.check(t, tt.input)
+ })
+ }
+}
+
+func TestAddServiceParams(t *testing.T) {
+ t.Run("missing service.json returns nil", func(t *testing.T) {
+ g := &BundleGenerator{
+ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
+ }
+
+ origStateDir := configs.StateDir
+ configs.StateDir = t.TempDir()
+ t.Cleanup(func() { configs.StateDir = origStateDir })
+
+ err := g.addServiceParams()
+ assert.NoError(t, err)
+ })
+
+ t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
+ dir := t.TempDir()
+ origStateDir := configs.StateDir
+ configs.StateDir = dir
+ t.Cleanup(func() { configs.StateDir = origStateDir })
+
+ input := map[string]any{
+ jsonKeyManagementURL: "https://api.example.com:443",
+ jsonKeyServiceEnv: map[string]any{
+ "NB_LOG_LEVEL": "trace",
+ },
+ }
+ data, err := json.Marshal(input)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
+
+ var buf bytes.Buffer
+ zw := zip.NewWriter(&buf)
+
+ g := &BundleGenerator{
+ anonymize: true,
+ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
+ archive: zw,
+ }
+
+ require.NoError(t, g.addServiceParams())
+ require.NoError(t, zw.Close())
+
+ zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
+ require.NoError(t, err)
+ require.Len(t, zr.File, 1)
+ assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
+
+ rc, err := zr.File[0].Open()
+ require.NoError(t, err)
+ defer rc.Close()
+
+ var result map[string]any
+ require.NoError(t, json.NewDecoder(rc).Decode(&result))
+
+ mgmt := result[jsonKeyManagementURL].(string)
+ assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
+ assert.NotEmpty(t, mgmt)
+
+ env := result[jsonKeyServiceEnv].(map[string]any)
+ assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
+ })
+
+ t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
+ dir := t.TempDir()
+ origStateDir := configs.StateDir
+ configs.StateDir = dir
+ t.Cleanup(func() { configs.StateDir = origStateDir })
+
+ input := map[string]any{
+ jsonKeyManagementURL: "https://api.example.com:443",
+ }
+ data, err := json.Marshal(input)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
+
+ var buf bytes.Buffer
+ zw := zip.NewWriter(&buf)
+
+ g := &BundleGenerator{
+ anonymize: false,
+ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
+ archive: zw,
+ }
+
+ require.NoError(t, g.addServiceParams())
+ require.NoError(t, zw.Close())
+
+ zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
+ require.NoError(t, err)
+
+ rc, err := zr.File[0].Open()
+ require.NoError(t, err)
+ defer rc.Close()
+
+ var result map[string]any
+ require.NoError(t, json.NewDecoder(rc).Decode(&result))
+
+ assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
+ })
+}
+
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go
index 73f70035f..2c6b7dbc3 100644
--- a/client/internal/dns/local/local_test.go
+++ b/client/internal/dns/local/local_test.go
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
-// TestLocalResolver_Stop tests cleanup on Stop
+// TestLocalResolver_Stop tests cleanup on GracefullyStop
func TestLocalResolver_Stop(t *testing.T) {
- t.Run("Stop clears all state", func(t *testing.T) {
+ t.Run("GracefullyStop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
- t.Run("Stop is safe to call multiple times", func(t *testing.T) {
+ t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
- t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
+ t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index d4fda5db3..f7865047b 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -187,11 +187,16 @@ func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
+ hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
+ log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
+ ds.hostsDNSHolder.set(hostsDnsList)
+ ds.permanent = true
+ ds.addHostRootZone()
return ds
}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 16410519b..ce4d71e35 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -47,6 +47,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
+ "github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -214,9 +215,10 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
- relayManager *relayClient.Manager
- stateManager *statemanager.Manager
- srWatcher *guard.SRWatcher
+ relayManager *relayClient.Manager
+ stateManager *statemanager.Manager
+ portForwardManager *portforward.Manager
+ srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -263,26 +265,27 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
- clientCtx: clientCtx,
- clientCancel: clientCancel,
- signal: services.SignalClient,
- signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
- mgmClient: services.MgmClient,
- relayManager: services.RelayManager,
- peerStore: peerstore.NewConnStore(),
- syncMsgMux: &sync.Mutex{},
- config: config,
- mobileDep: mobileDep,
- STUNs: []*stun.URI{},
- TURNs: []*stun.URI{},
- networkSerial: 0,
- statusRecorder: services.StatusRecorder,
- stateManager: services.StateManager,
- checks: services.Checks,
- probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
- jobExecutor: jobexec.NewExecutor(),
- clientMetrics: services.ClientMetrics,
- updateManager: services.UpdateManager,
+ clientCtx: clientCtx,
+ clientCancel: clientCancel,
+ signal: services.SignalClient,
+ signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
+ mgmClient: services.MgmClient,
+ relayManager: services.RelayManager,
+ peerStore: peerstore.NewConnStore(),
+ syncMsgMux: &sync.Mutex{},
+ config: config,
+ mobileDep: mobileDep,
+ STUNs: []*stun.URI{},
+ TURNs: []*stun.URI{},
+ networkSerial: 0,
+ statusRecorder: services.StatusRecorder,
+ stateManager: services.StateManager,
+ portForwardManager: portforward.NewManager(),
+ checks: services.Checks,
+ probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
+ jobExecutor: jobexec.NewExecutor(),
+ clientMetrics: services.ClientMetrics,
+ updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -541,6 +544,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
+ // Start after interface is up since port may have been resolved from 0 or changed if occupied
+ e.shutdownWg.Add(1)
+ go func() {
+ defer e.shutdownWg.Done()
+ e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
+ }()
+
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -1627,12 +1637,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
- StatusRecorder: e.statusRecorder,
- Signaler: e.signaler,
- IFaceDiscover: e.mobileDep.IFaceDiscover,
- RelayManager: e.relayManager,
- SrWatcher: e.srWatcher,
- MetricsRecorder: e.clientMetrics,
+ StatusRecorder: e.statusRecorder,
+ Signaler: e.signaler,
+ IFaceDiscover: e.mobileDep.IFaceDiscover,
+ RelayManager: e.relayManager,
+ SrWatcher: e.srWatcher,
+ PortForwardManager: e.portForwardManager,
+ MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1789,6 +1800,12 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
+ log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
+ }
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1894,7 +1911,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
- dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
+ dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index bea0725f2..8d1585b3f 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -22,6 +22,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
+ "github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
+ PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}
@@ -87,16 +89,17 @@ type ConnConfig struct {
}
type Conn struct {
- Log *log.Entry
- mu sync.Mutex
- ctx context.Context
- ctxCancel context.CancelFunc
- config ConnConfig
- statusRecorder *Status
- signaler *Signaler
- iFaceDiscover stdnet.ExternalIFaceDiscover
- relayManager *relayClient.Manager
- srWatcher *guard.SRWatcher
+ Log *log.Entry
+ mu sync.Mutex
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ config ConnConfig
+ statusRecorder *Status
+ signaler *Signaler
+ iFaceDiscover stdnet.ExternalIFaceDiscover
+ relayManager *relayClient.Manager
+ srWatcher *guard.SRWatcher
+ portForwardManager *portforward.Manager
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
- Log: connLog,
- config: config,
- statusRecorder: services.StatusRecorder,
- signaler: services.Signaler,
- iFaceDiscover: services.IFaceDiscover,
- relayManager: services.RelayManager,
- srWatcher: services.SrWatcher,
- statusRelay: worker.NewAtomicStatus(),
- statusICE: worker.NewAtomicStatus(),
- dumpState: dumpState,
- endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
- wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
- metricsRecorder: services.MetricsRecorder,
+ Log: connLog,
+ config: config,
+ statusRecorder: services.StatusRecorder,
+ signaler: services.Signaler,
+ iFaceDiscover: services.IFaceDiscover,
+ relayManager: services.RelayManager,
+ srWatcher: services.SrWatcher,
+ portForwardManager: services.PortForwardManager,
+ statusRelay: worker.NewAtomicStatus(),
+ statusICE: worker.NewAtomicStatus(),
+ dumpState: dumpState,
+ endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
+ wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
+ metricsRecorder: services.MetricsRecorder,
}
return conn, nil
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index edd70fb20..29bf5aaaa 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
+ "github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -61,6 +62,9 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
+
+ // portForwardAttempted tracks if we've already tried port forwarding this session
+ portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
+ w.portForwardAttempted = false
+
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
+
+ if candidate.Type() == ice.CandidateTypeServerReflexive {
+ w.injectPortForwardedCandidate(candidate)
+ }
+}
+
+// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
+func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
+ pfManager := w.conn.portForwardManager
+ if pfManager == nil {
+ return
+ }
+
+ mapping := pfManager.GetMapping()
+ if mapping == nil {
+ return
+ }
+
+ w.muxAgent.Lock()
+ if w.portForwardAttempted {
+ w.muxAgent.Unlock()
+ return
+ }
+ w.portForwardAttempted = true
+ w.muxAgent.Unlock()
+
+ forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
+ if err != nil {
+ w.log.Warnf("create forwarded candidate: %v", err)
+ return
+ }
+
+ w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
+ forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
+
+ go func() {
+ if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
+ w.log.Errorf("signal port-forwarded candidate: %v", err)
+ }
+ }()
+}
+
+// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
+// It uses the NAT gateway's external IP with the forwarded port.
+func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
+ var externalIP string
+ if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
+ externalIP = mapping.ExternalIP.String()
+ } else {
+ // Fallback to STUN-discovered address if NAT didn't provide external IP
+ externalIP = srflxCandidate.Address()
+ }
+
+ // Per RFC 8445, the related address for srflx is the base (host candidate address).
+ // If the original srflx has unspecified related address, use its own address as base.
+ relAddr := srflxCandidate.RelatedAddress().Address
+ if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
+ relAddr = srflxCandidate.Address()
+ }
+
+ // Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
+ // over regular srflx during ICE connectivity checks.
+ priority := srflxCandidate.Priority() + 1000
+
+ candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
+ Network: srflxCandidate.NetworkType().String(),
+ Address: externalIP,
+ Port: int(mapping.ExternalPort),
+ Component: srflxCandidate.Component(),
+ Priority: priority,
+ RelAddr: relAddr,
+ RelPort: int(mapping.InternalPort),
+ })
+ if err != nil {
+ return nil, fmt.Errorf("create candidate: %w", err)
+ }
+
+ for _, e := range srflxCandidate.Extensions() {
+ if e.Key == ice.ExtensionKeyCandidateID {
+ e.Value = srflxCandidate.ID()
+ }
+ if err := candidate.AddExtension(e); err != nil {
+ return nil, fmt.Errorf("add extension: %w", err)
+ }
+ }
+
+ return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
- w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
+ w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
sessionID,
- local.NetworkType(), local.Type(), local.Address(),
- remote.NetworkType(), remote.Type(), remote.Address(),
+ local.NetworkType(), local.Type(), local.Address(), local.Port(),
+ remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
stat.CurrentRoundTripTime*1000)
}
}
diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go
new file mode 100644
index 000000000..444a6b478
--- /dev/null
+++ b/client/internal/portforward/env.go
@@ -0,0 +1,26 @@
+package portforward
+
+import (
+ "os"
+ "strconv"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
+)
+
+func isDisabledByEnv() bool {
+ val := os.Getenv(envDisableNATMapper)
+ if val == "" {
+ return false
+ }
+
+ disabled, err := strconv.ParseBool(val)
+ if err != nil {
+ log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
+ return false
+ }
+ return disabled
+}
diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go
new file mode 100644
index 000000000..bf7533af9
--- /dev/null
+++ b/client/internal/portforward/manager.go
@@ -0,0 +1,280 @@
+//go:build !js
+
+package portforward
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "regexp"
+ "sync"
+ "time"
+
+ "github.com/libp2p/go-nat"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ defaultMappingTTL = 2 * time.Hour
+ discoveryTimeout = 10 * time.Second
+ mappingDescription = "NetBird"
+)
+
+// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
+// allowing for whitespace/newlines between tags from different router firmware.
+var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
+
+// Mapping represents an active NAT port mapping.
+type Mapping struct {
+ Protocol string
+ InternalPort uint16
+ ExternalPort uint16
+ ExternalIP net.IP
+ NATType string
+ // TTL is the lease duration. Zero means a permanent lease that never expires.
+ TTL time.Duration
+}
+
+// TODO: persist mapping state for crash recovery cleanup of permanent leases.
+// Currently not done because State.Cleanup requires NAT gateway re-discovery,
+// which blocks startup for ~10s when no gateway is present (affects all clients).
+
+type Manager struct {
+ cancel context.CancelFunc
+
+ mapping *Mapping
+ mappingLock sync.Mutex
+
+ wgPort uint16
+
+ done chan struct{}
+ stopCtx chan context.Context
+
+ // protect exported functions
+ mu sync.Mutex
+}
+
+// NewManager creates a new port forwarding manager.
+func NewManager() *Manager {
+ return &Manager{
+ stopCtx: make(chan context.Context, 1),
+ }
+}
+
+func (m *Manager) Start(ctx context.Context, wgPort uint16) {
+ m.mu.Lock()
+ if m.cancel != nil {
+ m.mu.Unlock()
+ return
+ }
+
+ if isDisabledByEnv() {
+ log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
+ m.mu.Unlock()
+ return
+ }
+
+ if wgPort == 0 {
+ log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
+ m.mu.Unlock()
+ return
+ }
+ m.wgPort = wgPort
+
+ m.done = make(chan struct{})
+ defer close(m.done)
+
+ ctx, m.cancel = context.WithCancel(ctx)
+ m.mu.Unlock()
+
+ gateway, mapping, err := m.setup(ctx)
+ if err != nil {
+ log.Infof("port forwarding setup: %v", err)
+ return
+ }
+
+ m.mappingLock.Lock()
+ m.mapping = mapping
+ m.mappingLock.Unlock()
+
+ m.renewLoop(ctx, gateway, mapping.TTL)
+
+ select {
+ case cleanupCtx := <-m.stopCtx:
+ // block Start until cleanup completes gracefully
+ m.cleanup(cleanupCtx, gateway)
+ default:
+ // let Start return immediately and clean up in the background
+ cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
+ go func() {
+ defer cleanupCancel()
+ m.cleanup(cleanupCtx, gateway)
+ }()
+ }
+}
+
+// GetMapping returns the current mapping if ready, nil otherwise
+func (m *Manager) GetMapping() *Mapping {
+ m.mappingLock.Lock()
+ defer m.mappingLock.Unlock()
+
+ if m.mapping == nil {
+ return nil
+ }
+
+ mapping := *m.mapping
+ return &mapping
+}
+
+// GracefullyStop cancels the manager and attempts to delete the port mapping.
+// After GracefullyStop returns, the manager cannot be restarted.
+func (m *Manager) GracefullyStop(ctx context.Context) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.cancel == nil {
+ return nil
+ }
+
+ // Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
+ m.startTearDown(ctx)
+
+ m.cancel()
+ m.cancel = nil
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-m.done:
+ return nil
+ }
+}
+
+func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
+ discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
+ defer discoverCancel()
+
+ gateway, err := nat.DiscoverGateway(discoverCtx)
+ if err != nil {
+ return nil, nil, fmt.Errorf("discover gateway: %w", err)
+ }
+
+ log.Infof("discovered NAT gateway: %s", gateway.Type())
+
+ mapping, err := m.createMapping(ctx, gateway)
+ if err != nil {
+ return nil, nil, fmt.Errorf("create port mapping: %w", err)
+ }
+ return gateway, mapping, nil
+}
+
+func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+
+ ttl := defaultMappingTTL
+ externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
+ if err != nil {
+ if !isPermanentLeaseRequired(err) {
+ return nil, err
+ }
+ log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
+ ttl = 0
+ externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ externalIP, err := gateway.GetExternalAddress()
+ if err != nil {
+ log.Debugf("failed to get external address: %v", err)
+ // todo return with err?
+ }
+
+ mapping := &Mapping{
+ Protocol: "udp",
+ InternalPort: m.wgPort,
+ ExternalPort: uint16(externalPort),
+ ExternalIP: externalIP,
+ NATType: gateway.Type(),
+ TTL: ttl,
+ }
+
+ log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
+ m.wgPort, externalPort, gateway.Type(), externalIP)
+ return mapping, nil
+}
+
+func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
+ if ttl == 0 {
+ // Permanent mappings don't expire, just wait for cancellation.
+ <-ctx.Done()
+ return
+ }
+
+ ticker := time.NewTicker(ttl / 2)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ if err := m.renewMapping(ctx, gateway); err != nil {
+ log.Warnf("failed to renew port mapping: %v", err)
+ continue
+ }
+ }
+ }
+}
+
+func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+
+ externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
+ if err != nil {
+ return fmt.Errorf("add port mapping: %w", err)
+ }
+
+ if uint16(externalPort) != m.mapping.ExternalPort {
+ log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
+ m.mappingLock.Lock()
+ m.mapping.ExternalPort = uint16(externalPort)
+ m.mappingLock.Unlock()
+ }
+
+ log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
+ return nil
+}
+
+func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
+ m.mappingLock.Lock()
+ mapping := m.mapping
+ m.mapping = nil
+ m.mappingLock.Unlock()
+
+ if mapping == nil {
+ return
+ }
+
+ if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
+ log.Warnf("delete port mapping on stop: %v", err)
+ return
+ }
+
+ log.Infof("deleted port mapping for port %d", mapping.InternalPort)
+}
+
+func (m *Manager) startTearDown(ctx context.Context) {
+ select {
+ case m.stopCtx <- ctx:
+ default:
+ }
+}
+
+// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
+func isPermanentLeaseRequired(err error) bool {
+ return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
+}
diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go
new file mode 100644
index 000000000..36c55063b
--- /dev/null
+++ b/client/internal/portforward/manager_js.go
@@ -0,0 +1,39 @@
+package portforward
+
+import (
+ "context"
+ "net"
+ "time"
+)
+
+// Mapping represents an active NAT port mapping.
+type Mapping struct {
+ Protocol string
+ InternalPort uint16
+ ExternalPort uint16
+ ExternalIP net.IP
+ NATType string
+ // TTL is the lease duration. Zero means a permanent lease that never expires.
+ TTL time.Duration
+}
+
+// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
+type Manager struct{}
+
+// NewManager returns a stub manager for js/wasm builds.
+func NewManager() *Manager {
+ return &Manager{}
+}
+
+// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
+func (m *Manager) Start(context.Context, uint16) {
+ // no NAT traversal in wasm
+}
+
+// GracefullyStop is a no-op on js/wasm.
+func (m *Manager) GracefullyStop(context.Context) error { return nil }
+
+// GetMapping always returns nil on js/wasm.
+func (m *Manager) GetMapping() *Mapping {
+ return nil
+}
diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go
new file mode 100644
index 000000000..1f66f9ccd
--- /dev/null
+++ b/client/internal/portforward/manager_test.go
@@ -0,0 +1,201 @@
+//go:build !js
+
+package portforward
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type mockNAT struct {
+ natType string
+ deviceAddr net.IP
+ externalAddr net.IP
+ internalAddr net.IP
+ mappings map[int]int
+ addMappingErr error
+ deleteMappingErr error
+ onlyPermanentLeases bool
+ lastTimeout time.Duration
+}
+
+func newMockNAT() *mockNAT {
+ return &mockNAT{
+ natType: "Mock-NAT",
+ deviceAddr: net.ParseIP("192.168.1.1"),
+ externalAddr: net.ParseIP("203.0.113.50"),
+ internalAddr: net.ParseIP("192.168.1.100"),
+ mappings: make(map[int]int),
+ }
+}
+
+func (m *mockNAT) Type() string {
+ return m.natType
+}
+
+func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
+ return m.deviceAddr, nil
+}
+
+func (m *mockNAT) GetExternalAddress() (net.IP, error) {
+ return m.externalAddr, nil
+}
+
+func (m *mockNAT) GetInternalAddress() (net.IP, error) {
+ return m.internalAddr, nil
+}
+
+func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
+ if m.addMappingErr != nil {
+ return 0, m.addMappingErr
+ }
+ if m.onlyPermanentLeases && timeout != 0 {
+ return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription>")
+ }
+ externalPort := internalPort
+ m.mappings[internalPort] = externalPort
+ m.lastTimeout = timeout
+ return externalPort, nil
+}
+
+func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
+ if m.deleteMappingErr != nil {
+ return m.deleteMappingErr
+ }
+ delete(m.mappings, internalPort)
+ return nil
+}
+
+func TestManager_CreateMapping(t *testing.T) {
+ m := NewManager()
+ m.wgPort = 51820
+
+ gateway := newMockNAT()
+ mapping, err := m.createMapping(context.Background(), gateway)
+ require.NoError(t, err)
+ require.NotNil(t, mapping)
+
+ assert.Equal(t, "udp", mapping.Protocol)
+ assert.Equal(t, uint16(51820), mapping.InternalPort)
+ assert.Equal(t, uint16(51820), mapping.ExternalPort)
+ assert.Equal(t, "Mock-NAT", mapping.NATType)
+ assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
+ assert.Equal(t, defaultMappingTTL, mapping.TTL)
+}
+
+func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
+ m := NewManager()
+ assert.Nil(t, m.GetMapping())
+}
+
+func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
+ m := NewManager()
+ m.mapping = &Mapping{
+ Protocol: "udp",
+ InternalPort: 51820,
+ ExternalPort: 51820,
+ }
+
+ mapping := m.GetMapping()
+ require.NotNil(t, mapping)
+ assert.Equal(t, uint16(51820), mapping.InternalPort)
+
+ // Mutating the returned copy should not affect the manager's mapping.
+ mapping.ExternalPort = 9999
+ assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
+}
+
+func TestManager_Cleanup_DeletesMapping(t *testing.T) {
+ m := NewManager()
+ m.mapping = &Mapping{
+ Protocol: "udp",
+ InternalPort: 51820,
+ ExternalPort: 51820,
+ }
+
+ gateway := newMockNAT()
+ // Seed the mock so we can verify deletion.
+ gateway.mappings[51820] = 51820
+
+ m.cleanup(context.Background(), gateway)
+
+ _, exists := gateway.mappings[51820]
+ assert.False(t, exists, "mapping should be deleted from gateway")
+ assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
+}
+
+func TestManager_Cleanup_NilMapping(t *testing.T) {
+ m := NewManager()
+ gateway := newMockNAT()
+
+ // Should not panic or call gateway.
+ m.cleanup(context.Background(), gateway)
+}
+
+
+func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
+ m := NewManager()
+ m.wgPort = 51820
+
+ gateway := newMockNAT()
+ gateway.onlyPermanentLeases = true
+
+ mapping, err := m.createMapping(context.Background(), gateway)
+ require.NoError(t, err)
+ require.NotNil(t, mapping)
+
+ assert.Equal(t, uint16(51820), mapping.InternalPort)
+ assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
+ assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
+}
+
+func TestIsPermanentLeaseRequired(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {
+ name: "nil error",
+ err: nil,
+ expected: false,
+ },
+ {
+ name: "UPnP error 725",
+ err: fmt.Errorf("SOAP fault. Code: | Detail: <errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription>"),
+ expected: true,
+ },
+ {
+ name: "wrapped error with 725",
+ err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
+ expected: true,
+ },
+ {
+ name: "error 725 with newlines in XML",
+ err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
+ expected: true,
+ },
+ {
+ name: "bare 725 without XML tag",
+ err: fmt.Errorf("error code 725"),
+ expected: false,
+ },
+ {
+ name: "unrelated error",
+ err: fmt.Errorf("connection refused"),
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
+ })
+ }
+}
diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go
index bb125cfa4..343d2799e 100644
--- a/client/internal/routemanager/notifier/notifier_ios.go
+++ b/client/internal/routemanager/notifier/notifier_ios.go
@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
n.currentPrefixes = newNets
n.notify()
}
-
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go
index c73a0dcd1..f15b97d30 100644
--- a/client/ios/NetBirdSDK/client.go
+++ b/client/ios/NetBirdSDK/client.go
@@ -162,7 +162,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
- return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
+ hostDNS := []netip.AddrPort{
+ netip.MustParseAddrPort("9.9.9.9:53"),
+ netip.MustParseAddrPort("149.112.112.112:53"),
+ }
+ return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
}
// Stop the internal client and free the resources
diff --git a/client/server/state_generic.go b/client/server/state_generic.go
index 980ba0cda..86475ca42 100644
--- a/client/server/state_generic.go
+++ b/client/server/state_generic.go
@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
+// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
diff --git a/client/server/state_linux.go b/client/server/state_linux.go
index 019477d8e..b193d4dfa 100644
--- a/client/server/state_linux.go
+++ b/client/server/state_linux.go
@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
+// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go
index 8897b9c7e..59007f75c 100644
--- a/client/ssh/proxy/proxy.go
+++ b/client/ssh/proxy/proxy.go
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
- hasCommand := len(session.Command()) > 0
+ hasCommand := session.RawCommand() != ""
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
}
if hasCommand {
- if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
+ if err := serverSession.Run(session.RawCommand()); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}
diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go
index dba2e88da..b33d5f8f4 100644
--- a/client/ssh/proxy/proxy_test.go
+++ b/client/ssh/proxy/proxy_test.go
@@ -1,6 +1,7 @@
package proxy
import (
+ "bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel()
}
+// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
+// when forwarding commands to the backend. This is critical for tools like
+// Ansible that send commands such as:
+//
+// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
+//
+// The single quotes must be preserved so the backend shell receives the
+// subshell expression as a single argument to -c.
+func TestSSHProxy_CommandQuoting(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping integration test in short mode")
+ }
+
+ sshClient, cleanup := setupProxySSHClient(t)
+ defer cleanup()
+
+ // These commands simulate what the SSH protocol delivers as exec payloads.
+ // When a user types: ssh host '/bin/sh -c "( echo hello )"'
+ // the local shell strips the outer single quotes, and the SSH exec request
+ // contains the raw string: /bin/sh -c "( echo hello )"
+ //
+ // The proxy must forward this string verbatim. Using session.Command()
+ // (shlex.Split + strings.Join) strips the inner double quotes, breaking
+ // the command on the backend.
+ tests := []struct {
+ name string
+ command string
+ expect string
+ }{
+ {
+ name: "subshell_in_double_quotes",
+ command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
+ expect: "from-subshell\nouter\n",
+ },
+ {
+ name: "printf_with_special_chars",
+ command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
+ expect: "hello world\n",
+ },
+ {
+ name: "nested_command_substitution",
+ command: `/bin/sh -c "echo $(echo nested)"`,
+ expect: "nested\n",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ session, err := sshClient.NewSession()
+ require.NoError(t, err)
+ defer func() { _ = session.Close() }()
+
+ var stderrBuf bytes.Buffer
+ session.Stderr = &stderrBuf
+
+ outputCh := make(chan []byte, 1)
+ errCh := make(chan error, 1)
+ go func() {
+ output, err := session.Output(tc.command)
+ outputCh <- output
+ errCh <- err
+ }()
+
+ select {
+ case output := <-outputCh:
+ err := <-errCh
+ if stderrBuf.Len() > 0 {
+ t.Logf("stderr: %s", stderrBuf.String())
+ }
+ require.NoError(t, err, "command should succeed: %s", tc.command)
+ assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
+ case <-time.After(5 * time.Second):
+ t.Fatalf("command timed out: %s", tc.command)
+ }
+ })
+ }
+}
+
+// setupProxySSHClient creates a full proxy test environment and returns
+// an SSH client connected through the proxy to a backend NetBird SSH server.
+func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
+ t.Helper()
+
+ const (
+ issuer = "https://test-issuer.example.com"
+ audience = "test-audience"
+ )
+
+ jwksServer, privateKey, jwksURL := setupJWKSServer(t)
+
+ hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
+ require.NoError(t, err)
+ hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
+ require.NoError(t, err)
+
+ serverConfig := &server.Config{
+ HostKeyPEM: hostKey,
+ JWT: &server.JWTConfig{
+ Issuer: issuer,
+ Audiences: []string{audience},
+ KeysLocation: jwksURL,
+ },
+ }
+ sshServer := server.New(serverConfig)
+ sshServer.SetAllowRootLogin(true)
+
+ testUsername := testutil.GetTestUsername(t)
+ testJWTUser := "test-username"
+ testUserHash, err := sshuserhash.HashUserID(testJWTUser)
+ require.NoError(t, err)
+
+ authConfig := &sshauth.Config{
+ UserIDClaim: sshauth.DefaultUserIDClaim,
+ AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
+ MachineUsers: map[string][]uint32{
+ testUsername: {0},
+ },
+ }
+ sshServer.UpdateSSHAuth(authConfig)
+
+ sshServerAddr := server.StartTestServer(t, sshServer)
+
+ mockDaemon := startMockDaemon(t)
+
+ host, portStr, err := net.SplitHostPort(sshServerAddr)
+ require.NoError(t, err)
+ port, err := strconv.Atoi(portStr)
+ require.NoError(t, err)
+
+ mockDaemon.setHostKey(host, hostPubKey)
+
+ validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
+ mockDaemon.setJWTToken(validToken)
+
+ proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
+ require.NoError(t, err)
+
+ origStdin := os.Stdin
+ origStdout := os.Stdout
+
+ stdinReader, stdinWriter, err := os.Pipe()
+ require.NoError(t, err)
+ stdoutReader, stdoutWriter, err := os.Pipe()
+ require.NoError(t, err)
+
+ os.Stdin = stdinReader
+ os.Stdout = stdoutWriter
+
+ clientConn, proxyConn := net.Pipe()
+
+ go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
+ go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+
+ go func() {
+ _ = proxyInstance.Connect(ctx)
+ }()
+
+ sshConfig := &cryptossh.ClientConfig{
+ User: testutil.GetTestUsername(t),
+ Auth: []cryptossh.AuthMethod{},
+ HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
+ Timeout: 5 * time.Second,
+ }
+
+ sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
+ require.NoError(t, err)
+
+ client := cryptossh.NewClient(sshClientConn, chans, reqs)
+
+ cleanupFn := func() {
+ _ = client.Close()
+ _ = clientConn.Close()
+ cancel()
+ os.Stdin = origStdin
+ os.Stdout = origStdout
+ _ = sshServer.Stop()
+ mockDaemon.stop()
+ jwksServer.Close()
+ }
+
+ return client, cleanupFn
+}
+
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go
index f12a75961..0e531bb96 100644
--- a/client/ssh/server/session_handlers.go
+++ b/client/ssh/server/session_handlers.go
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
}
ptyReq, winCh, isPty := session.Pty()
- hasCommand := len(session.Command()) > 0
+ hasCommand := session.RawCommand() != ""
if isPty && !hasCommand {
// ssh - PTY interactive session (login)
diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go
index 8e1353151..755172842 100644
--- a/client/system/info_freebsd.go
+++ b/client/system/info_freebsd.go
@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
systemHostname, _ := os.Hostname()
+ addrs, err := networkAddresses()
+ if err != nil {
+ log.Warnf("failed to discover network addresses: %s", err)
+ }
+
return &Info{
- GoOS: runtime.GOOS,
- Kernel: osInfo[0],
- Platform: runtime.GOARCH,
- OS: osName,
- OSVersion: osVersion,
- Hostname: extractDeviceName(ctx, systemHostname),
- CPUs: runtime.NumCPU(),
- NetbirdVersion: version.NetbirdVersion(),
- UIVersion: extractUserAgent(ctx),
- KernelVersion: osInfo[1],
- Environment: env,
+ GoOS: runtime.GOOS,
+ Kernel: osInfo[0],
+ Platform: runtime.GOARCH,
+ OS: osName,
+ OSVersion: osVersion,
+ Hostname: extractDeviceName(ctx, systemHostname),
+ CPUs: runtime.NumCPU(),
+ NetbirdVersion: version.NetbirdVersion(),
+ UIVersion: extractUserAgent(ctx),
+ KernelVersion: osInfo[1],
+ NetworkAddresses: addrs,
+ Environment: env,
}
}
diff --git a/client/ui/debug.go b/client/ui/debug.go
index 29f73a66a..4ebe4d675 100644
--- a/client/ui/debug.go
+++ b/client/ui/debug.go
@@ -24,9 +24,10 @@ import (
// Initial state for the debug collection
type debugInitialState struct {
- wasDown bool
- logLevel proto.LogLevel
- isLevelTrace bool
+ wasDown bool
+ needsRestoreUp bool
+ logLevel proto.LogLevel
+ isLevelTrace bool
}
// Debug collection parameters
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
-) error {
+) {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
- return fmt.Errorf("bring service up: %v", err)
+ log.Warnf("failed to bring service up: %v", err)
+ } else {
+ log.Info("Service brought up for debug")
+ time.Sleep(time.Second * 10)
}
- log.Info("Service brought up for debug")
- time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
- return fmt.Errorf("set log level to TRACE: %v", err)
+ log.Warnf("failed to set log level to TRACE: %v", err)
+ } else {
+ log.Info("Log level set to TRACE for debug")
}
- log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
- return fmt.Errorf("bring service down: %v", err)
+ log.Warnf("failed to bring service down: %v", err)
+ } else {
+ state.needsRestoreUp = !state.wasDown
+ time.Sleep(time.Second)
}
- time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
- return fmt.Errorf("enable sync response persistence: %v", err)
+ log.Warnf("failed to enable sync response persistence: %v", err)
+ } else {
+ log.Info("Sync response persistence enabled for debug")
}
- log.Info("Sync response persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
- return fmt.Errorf("bring service back up: %v", err)
+ log.Warnf("failed to bring service back up: %v", err)
+ } else {
+ state.needsRestoreUp = false
+ time.Sleep(time.Second * 3)
}
- time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
-
- return nil
}
func (s *serviceClient) collectDebugData(
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
- if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
- return err
- }
+ s.configureServiceForDebug(conn, state, params.enablePersistence)
wg.Wait()
progress.progressBar.Hide()
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
+ if state.needsRestoreUp {
+ if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
+ log.Warnf("failed to restore up state: %v", err)
+ } else {
+ log.Info("Service state restored to up")
+ }
+ }
+
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
- log.Errorf("Failed to restore down state: %v", err)
+ log.Warnf("failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
- log.Errorf("Failed to restore log level: %v", err)
+ log.Warnf("failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}
diff --git a/combined/cmd/root.go b/combined/cmd/root.go
index ea1ff908a..db986b4d4 100644
--- a/combined/cmd/root.go
+++ b/combined/cmd/root.go
@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
+ "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
@@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
- var relayAcceptFn func(conn net.Conn)
+ var relayAcceptFn func(conn listener.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
@@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
-func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
+func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
@@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
return
}
- lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
- if err != nil {
- _ = wsConn.Close(websocket.StatusInternalError, "internal error")
- return
- }
-
log.Debugf("Relay WS client connected from: %s", rAddr)
- conn := ws.NewConn(wsConn, lAddr, rAddr)
+ conn := ws.NewConn(wsConn, rAddr)
acceptFn(conn)
}
diff --git a/go.mod b/go.mod
index e9334f85b..a95192600 100644
--- a/go.mod
+++ b/go.mod
@@ -63,6 +63,7 @@ require (
github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
+ github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
@@ -200,10 +201,12 @@ require (
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
+ github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
+ github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
@@ -213,6 +216,7 @@ require (
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
+ github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/libdns/libdns v0.2.2 // indirect
diff --git a/go.sum b/go.sum
index 629388ccb..a1d2bb71f 100644
--- a/go.sum
+++ b/go.sum
@@ -281,6 +281,8 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
+github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY=
+github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -291,6 +293,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
+github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -328,6 +332,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
+github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
+github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -346,6 +352,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
+github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
+github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go
index 3b4b31765..524b099f1 100644
--- a/management/server/store/sql_store.go
+++ b/management/server/store/sql_store.go
@@ -2115,6 +2115,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
+ var terminated sql.NullBool
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2135,7 +2136,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.PortAutoAssigned,
&source,
&sourcePeer,
- &s.Terminated,
+ &terminated,
)
if err != nil {
return nil, err
@@ -2176,7 +2177,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String
}
-
+ if terminated.Valid {
+ s.Terminated = terminated.Bool
+ }
s.Targets = []*rpservice.Target{}
return &s, nil
})
diff --git a/management/server/types/networkmap_benchmark_test.go b/management/server/types/networkmap_benchmark_test.go
new file mode 100644
index 000000000..38272e7b0
--- /dev/null
+++ b/management/server/types/networkmap_benchmark_test.go
@@ -0,0 +1,217 @@
+package types_test
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "testing"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/management/server/types"
+)
+
+type benchmarkScale struct {
+ name string
+ peers int
+ groups int
+}
+
+var defaultScales = []benchmarkScale{
+ {"100peers_5groups", 100, 5},
+ {"500peers_20groups", 500, 20},
+ {"1000peers_50groups", 1000, 50},
+ {"5000peers_100groups", 5000, 100},
+ {"10000peers_200groups", 10000, 200},
+ {"20000peers_200groups", 20000, 200},
+ {"30000peers_300groups", 30000, 300},
+}
+
+func skipCIBenchmark(b *testing.B) {
+ if os.Getenv("CI") == "true" {
+ b.Skip("Skipping benchmark in CI")
+ }
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// Single Peer Network Map Generation
+// ──────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer.
+func BenchmarkNetworkMapGeneration_Components(b *testing.B) {
+ skipCIBenchmark(b)
+ for _, scale := range defaultScales {
+ b.Run(scale.name, func(b *testing.B) {
+ account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
+ ctx := context.Background()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
+ }
+ })
+ }
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// All Peers (UpdateAccountPeers hot path)
+// ──────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers.
+func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) {
+ skipCIBenchmark(b)
+ scales := []benchmarkScale{
+ {"100peers_5groups", 100, 5},
+ {"500peers_20groups", 500, 20},
+ {"1000peers_50groups", 1000, 50},
+ {"5000peers_100groups", 5000, 100},
+ }
+
+ for _, scale := range scales {
+ account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
+ ctx := context.Background()
+
+ peerIDs := make([]string, 0, len(account.Peers))
+ for peerID := range account.Peers {
+ peerIDs = append(peerIDs, peerID)
+ }
+
+ b.Run("components/"+scale.name, func(b *testing.B) {
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ for _, peerID := range peerIDs {
+ _ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
+ }
+ }
+ })
+ }
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// Sub-operations
+// ──────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction.
+func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) {
+ skipCIBenchmark(b)
+ for _, scale := range defaultScales {
+ b.Run(scale.name, func(b *testing.B) {
+ account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
+ ctx := context.Background()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
+ }
+ })
+ }
+}
+
+// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components.
+func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) {
+ skipCIBenchmark(b)
+ for _, scale := range defaultScales {
+ b.Run(scale.name, func(b *testing.B) {
+ account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
+ ctx := context.Background()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+ components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = types.CalculateNetworkMapFromComponents(ctx, components)
+ }
+ })
+ }
+}
+
+// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs.
+func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) {
+ skipCIBenchmark(b)
+ for _, scale := range defaultScales {
+ b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) {
+ account, _ := scalableTestAccount(scale.peers, scale.groups)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetResourcePoliciesMap()
+ }
+ })
+ b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) {
+ account, _ := scalableTestAccount(scale.peers, scale.groups)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetResourceRoutersMap()
+ }
+ })
+ b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) {
+ account, _ := scalableTestAccount(scale.peers, scale.groups)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetActiveGroupUsers()
+ }
+ })
+ }
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// Scaling Analysis
+// ──────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance.
+func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) {
+ skipCIBenchmark(b)
+ groupCounts := []int{1, 5, 20, 50, 100, 200, 500}
+ for _, numGroups := range groupCounts {
+ b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) {
+ account, validatedPeers := scalableTestAccount(1000, numGroups)
+ ctx := context.Background()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
+ }
+ })
+ }
+}
+
+// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance.
+func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) {
+ skipCIBenchmark(b)
+ peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000}
+ for _, numPeers := range peerCounts {
+ numGroups := numPeers / 20
+ if numGroups < 1 {
+ numGroups = 1
+ }
+ b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) {
+ account, validatedPeers := scalableTestAccount(numPeers, numGroups)
+ ctx := context.Background()
+ resourcePolicies := account.GetResourcePoliciesMap()
+ routers := account.GetResourceRoutersMap()
+ groupIDToUserIDs := account.GetActiveGroupUsers()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for range b.N {
+ _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
+ }
+ })
+ }
+}
diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go
new file mode 100644
index 000000000..5cd41ff10
--- /dev/null
+++ b/management/server/types/networkmap_components_correctness_test.go
@@ -0,0 +1,1192 @@
+package types_test
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ nbdns "github.com/netbirdio/netbird/dns"
+ resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
+ routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
+ networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/types"
+ "github.com/netbirdio/netbird/route"
+)
+
+// scalableTestAccountWithoutDefaultPolicy creates an account without the blanket "Allow All" policy.
+// Use this for tests that need to verify feature-specific connectivity in isolation.
+func scalableTestAccountWithoutDefaultPolicy(numPeers, numGroups int) (*types.Account, map[string]struct{}) {
+ return buildScalableTestAccount(numPeers, numGroups, false)
+}
+
+// scalableTestAccount creates a realistic account with a blanket "Allow All" policy
+// plus per-group policies, routes, network resources, posture checks, and DNS settings.
+func scalableTestAccount(numPeers, numGroups int) (*types.Account, map[string]struct{}) {
+ return buildScalableTestAccount(numPeers, numGroups, true)
+}
+
+// buildScalableTestAccount is the core builder. When withDefaultPolicy is true it adds
+// a blanket group-all <-> group-all allow rule; when false the only policies are the
+// per-group ones, so tests can verify feature-specific connectivity in isolation.
+func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) (*types.Account, map[string]struct{}) {
+ peers := make(map[string]*nbpeer.Peer, numPeers)
+ allGroupPeers := make([]string, 0, numPeers)
+
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)}
+ wtVersion := "0.25.0"
+ if i%2 == 0 {
+ wtVersion = "0.40.0"
+ }
+
+ p := &nbpeer.Peer{
+ ID: peerID,
+ IP: ip,
+ Key: fmt.Sprintf("key-%s", peerID),
+ DNSLabel: fmt.Sprintf("peer%d", i),
+ Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
+ UserID: "user-admin",
+ Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
+ }
+
+ if i == numPeers-2 {
+ p.LoginExpirationEnabled = true
+ pastTimestamp := time.Now().Add(-2 * time.Hour)
+ p.LastLogin = &pastTimestamp
+ }
+
+ peers[peerID] = p
+ allGroupPeers = append(allGroupPeers, peerID)
+ }
+
+ groups := make(map[string]*types.Group, numGroups+1)
+ groups["group-all"] = &types.Group{ID: "group-all", Name: "All", Peers: allGroupPeers}
+
+ peersPerGroup := numPeers / numGroups
+ if peersPerGroup < 1 {
+ peersPerGroup = 1
+ }
+
+ for g := range numGroups {
+ groupID := fmt.Sprintf("group-%d", g)
+ groupPeers := make([]string, 0, peersPerGroup)
+ start := g * peersPerGroup
+ end := start + peersPerGroup
+ if end > numPeers {
+ end = numPeers
+ }
+ for i := start; i < end; i++ {
+ groupPeers = append(groupPeers, fmt.Sprintf("peer-%d", i))
+ }
+ groups[groupID] = &types.Group{ID: groupID, Name: fmt.Sprintf("Group %d", g), Peers: groupPeers}
+ }
+
+ policies := make([]*types.Policy, 0, numGroups+2)
+ if withDefaultPolicy {
+ policies = append(policies, &types.Policy{
+ ID: "policy-all", Name: "Default-Allow", Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: "rule-all", Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{"group-all"}, Destinations: []string{"group-all"},
+ }},
+ })
+ }
+
+ for g := range numGroups {
+ groupID := fmt.Sprintf("group-%d", g)
+ dstGroup := fmt.Sprintf("group-%d", (g+1)%numGroups)
+ policies = append(policies, &types.Policy{
+ ID: fmt.Sprintf("policy-%d", g), Name: fmt.Sprintf("Policy %d", g), Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: fmt.Sprintf("rule-%d", g), Name: fmt.Sprintf("Rule %d", g), Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true,
+ Ports: []string{"8080"},
+ Sources: []string{groupID}, Destinations: []string{dstGroup},
+ }},
+ })
+ }
+
+ if numGroups >= 2 {
+ policies = append(policies, &types.Policy{
+ ID: "policy-drop", Name: "Drop DB traffic", Enabled: true,
+ Rules: []*types.PolicyRule{{
+ ID: "rule-drop", Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
+ Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ }},
+ })
+ }
+
+ numRoutes := numGroups
+ if numRoutes > 20 {
+ numRoutes = 20
+ }
+ routes := make(map[route.ID]*route.Route, numRoutes)
+ for r := range numRoutes {
+ routeID := route.ID(fmt.Sprintf("route-%d", r))
+ peerIdx := (numPeers / 2) + r
+ if peerIdx >= numPeers {
+ peerIdx = numPeers - 1
+ }
+ routePeerID := fmt.Sprintf("peer-%d", peerIdx)
+ groupID := fmt.Sprintf("group-%d", r%numGroups)
+ routes[routeID] = &route.Route{
+ ID: routeID,
+ Network: netip.MustParsePrefix(fmt.Sprintf("10.%d.0.0/16", r)),
+ Peer: peers[routePeerID].Key,
+ PeerID: routePeerID,
+ Description: fmt.Sprintf("Route %d", r),
+ Enabled: true,
+ PeerGroups: []string{groupID},
+ Groups: []string{"group-all"},
+ AccessControlGroups: []string{groupID},
+ AccountID: "test-account",
+ }
+ }
+
+ numResources := numGroups / 2
+ if numResources < 1 {
+ numResources = 1
+ }
+ if numResources > 50 {
+ numResources = 50
+ }
+
+ networkResources := make([]*resourceTypes.NetworkResource, 0, numResources)
+ networksList := make([]*networkTypes.Network, 0, numResources)
+ networkRouters := make([]*routerTypes.NetworkRouter, 0, numResources)
+
+ routingPeerStart := numPeers * 3 / 4
+ for nr := range numResources {
+ netID := fmt.Sprintf("net-%d", nr)
+ resID := fmt.Sprintf("res-%d", nr)
+ routerPeerIdx := routingPeerStart + nr
+ if routerPeerIdx >= numPeers {
+ routerPeerIdx = numPeers - 1
+ }
+ routerPeerID := fmt.Sprintf("peer-%d", routerPeerIdx)
+
+ networksList = append(networksList, &networkTypes.Network{ID: netID, Name: fmt.Sprintf("Network %d", nr), AccountID: "test-account"})
+ networkResources = append(networkResources, &resourceTypes.NetworkResource{
+ ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true,
+ Address: fmt.Sprintf("svc-%d.netbird.cloud", nr),
+ })
+ networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
+ ID: fmt.Sprintf("router-%d", nr), NetworkID: netID, Peer: routerPeerID,
+ Enabled: true, AccountID: "test-account",
+ })
+
+ policies = append(policies, &types.Policy{
+ ID: fmt.Sprintf("policy-res-%d", nr), Name: fmt.Sprintf("Resource Policy %d", nr), Enabled: true,
+ SourcePostureChecks: []string{"posture-check-ver"},
+ Rules: []*types.PolicyRule{{
+ ID: fmt.Sprintf("rule-res-%d", nr), Name: fmt.Sprintf("Allow Resource %d", nr), Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{fmt.Sprintf("group-%d", nr%numGroups)},
+ DestinationResource: types.Resource{ID: resID},
+ }},
+ })
+ }
+
+ account := &types.Account{
+ Id: "test-account",
+ Peers: peers,
+ Groups: groups,
+ Policies: policies,
+ Routes: routes,
+ Users: map[string]*types.User{
+ "user-admin": {Id: "user-admin", Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: "test-account"},
+ },
+ Network: &types.Network{
+ Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}, Serial: 1,
+ },
+ DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}},
+ NameServerGroups: map[string]*nbdns.NameServerGroup{
+ "ns-group-main": {
+ ID: "ns-group-main", Name: "Main NS", Enabled: true, Groups: []string{"group-all"},
+ NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
+ },
+ },
+ PostureChecks: []*posture.Checks{
+ {ID: "posture-check-ver", Name: "Check version", Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
+ }},
+ },
+ NetworkResources: networkResources,
+ Networks: networksList,
+ NetworkRouters: networkRouters,
+ Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
+ }
+
+ for _, p := range account.Policies {
+ p.AccountID = account.Id
+ }
+ for _, r := range account.Routes {
+ r.AccountID = account.Id
+ }
+
+ validatedPeers := make(map[string]struct{}, numPeers)
+ for i := range numPeers {
+ peerID := fmt.Sprintf("peer-%d", i)
+ if i != numPeers-1 {
+ validatedPeers[peerID] = struct{}{}
+ }
+ }
+
+ return account, validatedPeers
+}
+
+// componentsNetworkMap is a convenience wrapper for GetPeerNetworkMapFromComponents.
+func componentsNetworkMap(account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
+ return account.GetPeerNetworkMapFromComponents(
+ context.Background(), peerID, nbdns.CustomZone{}, nil,
+ validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(),
+ nil, account.GetActiveGroupUsers(),
+ )
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 1. PEER VISIBILITY & GROUPS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_PeerVisibility(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.Equal(t, len(validatedPeers)-1-len(nm.OfflinePeers), len(nm.Peers), "peer should see all other validated non-expired peers")
+}
+
+func TestComponents_PeerDoesNotSeeItself(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ for _, p := range nm.Peers {
+ assert.NotEqual(t, "peer-0", p.ID, "peer should not see itself")
+ }
+}
+
+func TestComponents_IntraGroupConnectivity(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-5"], "peer-0 should see peer-5 from same group")
+}
+
+func TestComponents_CrossGroupConnectivity(t *testing.T) {
+ // Without default policy, only per-group policies provide connectivity
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-10"], "peer-0 should see peer-10 from cross-group policy")
+}
+
+func TestComponents_BidirectionalPolicy(t *testing.T) {
+ // Without default policy so bidirectional visibility comes only from per-group policies
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5)
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ nm20 := componentsNetworkMap(account, "peer-20", validatedPeers)
+ require.NotNil(t, nm0)
+ require.NotNil(t, nm20)
+
+ peer0SeesPeer20 := false
+ for _, p := range nm0.Peers {
+ if p.ID == "peer-20" {
+ peer0SeesPeer20 = true
+ }
+ }
+ peer20SeesPeer0 := false
+ for _, p := range nm20.Peers {
+ if p.ID == "peer-0" {
+ peer20SeesPeer0 = true
+ }
+ }
+ assert.True(t, peer0SeesPeer20, "peer-0 should see peer-20 via bidirectional policy")
+ assert.True(t, peer20SeesPeer0, "peer-20 should see peer-0 via bidirectional policy")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 2. PEER EXPIRATION & ACCOUNT SETTINGS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_ExpiredPeerInOfflineList(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ offlineIDs := make(map[string]bool, len(nm.OfflinePeers))
+ for _, p := range nm.OfflinePeers {
+ offlineIDs[p.ID] = true
+ }
+ assert.True(t, offlineIDs["peer-98"], "expired peer should be in OfflinePeers")
+ for _, p := range nm.Peers {
+ assert.NotEqual(t, "peer-98", p.ID, "expired peer should not be in active Peers")
+ }
+}
+
+func TestComponents_ExpirationDisabledSetting(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ account.Settings.PeerLoginExpirationEnabled = false
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-98"], "with expiration disabled, peer-98 should be in active Peers")
+}
+
+func TestComponents_LoginExpiration_PeerLevel(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ account.Settings.PeerLoginExpirationEnabled = true
+ account.Settings.PeerLoginExpiration = 1 * time.Hour
+
+ pastLogin := time.Now().Add(-2 * time.Hour)
+ account.Peers["peer-5"].LastLogin = &pastLogin
+ account.Peers["peer-5"].LoginExpirationEnabled = true
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ offlineIDs := make(map[string]bool, len(nm.OfflinePeers))
+ for _, p := range nm.OfflinePeers {
+ offlineIDs[p.ID] = true
+ }
+ assert.True(t, offlineIDs["peer-5"], "login-expired peer should be in OfflinePeers")
+ for _, p := range nm.Peers {
+ assert.NotEqual(t, "peer-5", p.ID, "login-expired peer should not be in active Peers")
+ }
+}
+
+func TestComponents_NetworkSerial(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 5)
+ account.Network.Serial = 42
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.Equal(t, uint64(42), nm.Network.Serial, "network serial should match")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 3. NON-VALIDATED PEERS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_NonValidatedPeerExcluded(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ for _, p := range nm.Peers {
+ assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in Peers")
+ }
+ for _, p := range nm.OfflinePeers {
+ assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in OfflinePeers")
+ }
+}
+
+func TestComponents_NonValidatedTargetPeerGetsEmptyMap(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-99", validatedPeers)
+ require.NotNil(t, nm)
+ assert.Empty(t, nm.Peers)
+ assert.Empty(t, nm.FirewallRules)
+}
+
+func TestComponents_NonExistentPeerGetsEmptyMap(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-does-not-exist", validatedPeers)
+ require.NotNil(t, nm)
+ assert.Empty(t, nm.Peers)
+ assert.Empty(t, nm.FirewallRules)
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 4. POLICIES & FIREWALL RULES
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_FirewallRulesGenerated(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotEmpty(t, nm.FirewallRules, "should have firewall rules from policies")
+}
+
+func TestComponents_DropPolicyGeneratesDropRules(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ hasDropRule := false
+ for _, rule := range nm.FirewallRules {
+ if rule.Action == string(types.PolicyTrafficActionDrop) {
+ hasDropRule = true
+ break
+ }
+ }
+ assert.True(t, hasDropRule, "should have at least one drop firewall rule")
+}
+
+func TestComponents_DisabledPolicyIgnored(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 2)
+ for _, p := range account.Policies {
+ p.Enabled = false
+ }
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.Empty(t, nm.Peers, "disabled policies should yield no peers")
+ assert.Empty(t, nm.FirewallRules, "disabled policies should yield no firewall rules")
+}
+
+func TestComponents_PortPolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 2)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ has8080, has5432 := false, false
+ for _, rule := range nm.FirewallRules {
+ if rule.Port == "8080" {
+ has8080 = true
+ }
+ if rule.Port == "5432" {
+ has5432 = true
+ }
+ }
+ assert.True(t, has8080, "should have firewall rule for port 8080")
+ assert.True(t, has5432, "should have firewall rule for port 5432 (drop policy)")
+}
+
+func TestComponents_PortRangePolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 2)
+ account.Peers["peer-0"].Meta.WtVersion = "0.50.0"
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-port-range", Name: "Port Range", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-port-range", Name: "Port Range Rule", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true,
+ PortRanges: []types.RulePortRange{{Start: 8000, End: 9000}},
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ }},
+ })
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ hasPortRange := false
+ for _, rule := range nm.FirewallRules {
+ if rule.PortRange.Start == 8000 && rule.PortRange.End == 9000 {
+ hasPortRange = true
+ break
+ }
+ }
+ assert.True(t, hasPortRange, "should have firewall rule with port range 8000-9000")
+}
+
+func TestComponents_FirewallRuleDirection(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 2)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ hasIn, hasOut := false, false
+ for _, rule := range nm.FirewallRules {
+ if rule.Direction == types.FirewallRuleDirectionIN {
+ hasIn = true
+ }
+ if rule.Direction == types.FirewallRuleDirectionOUT {
+ hasOut = true
+ }
+ }
+ assert.True(t, hasIn, "should have inbound firewall rules")
+ assert.True(t, hasOut, "should have outbound firewall rules")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 5. ROUTES
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_RoutesIncluded(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotEmpty(t, nm.Routes, "should have routes")
+}
+
+func TestComponents_DisabledRouteExcluded(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 2)
+ for _, r := range account.Routes {
+ r.Enabled = false
+ }
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ for _, r := range nm.Routes {
+ assert.True(t, r.Enabled, "only enabled routes should appear")
+ }
+}
+
+func TestComponents_RoutesFirewallRulesForACG(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotEmpty(t, nm.RoutesFirewallRules, "should have route firewall rules for access-controlled routes")
+}
+
+func TestComponents_HARouteDeduplication(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 5)
+
+ haNetwork := netip.MustParsePrefix("172.16.0.0/16")
+ account.Routes["route-ha-1"] = &route.Route{
+ ID: "route-ha-1", Network: haNetwork, PeerID: "peer-10",
+ Peer: account.Peers["peer-10"].Key, Enabled: true, Metric: 100,
+ Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, AccountID: "test-account",
+ }
+ account.Routes["route-ha-2"] = &route.Route{
+ ID: "route-ha-2", Network: haNetwork, PeerID: "peer-20",
+ Peer: account.Peers["peer-20"].Key, Enabled: true, Metric: 200,
+ Groups: []string{"group-all"}, PeerGroups: []string{"group-1"}, AccountID: "test-account",
+ }
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ haRoutes := 0
+ for _, r := range nm.Routes {
+ if r.Network == haNetwork {
+ haRoutes++
+ }
+ }
+ // Components deduplicates HA routes with the same HA unique ID, returning one entry per HA group
+ assert.Equal(t, 1, haRoutes, "HA routes with same network should be deduplicated into one entry")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 6. NETWORK RESOURCES & ROUTERS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+
+ var routerPeerID string
+ for _, nr := range account.NetworkRouters {
+ routerPeerID = nr.Peer
+ break
+ }
+ require.NotEmpty(t, routerPeerID)
+
+ nm := componentsNetworkMap(account, routerPeerID, validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotEmpty(t, nm.Peers, "router peer should see source peers")
+}
+
+func TestComponents_NetworkResourceRoutes_SourcePeerSeesRouterPeer(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+
+ var routerPeerID string
+ for _, nr := range account.NetworkRouters {
+ routerPeerID = nr.Peer
+ break
+ }
+ require.NotEmpty(t, routerPeerID)
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs[routerPeerID], "source peer should see router peer for network resource")
+}
+
+func TestComponents_DisabledNetworkResourceIgnored(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 5)
+ for _, nr := range account.NetworkResources {
+ nr.Enabled = false
+ }
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotNil(t, nm.Network)
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 7. POSTURE CHECKS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_PostureCheckFiltering_PassingPeer(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.NotEmpty(t, nm.Routes, "passing peer should have routes including resource routes")
+}
+
+func TestComponents_PostureCheckFiltering_FailingPeer(t *testing.T) {
+ // peer-0 has version 0.40.0 (passes posture check >= 0.26.0)
+ // peer-1 has version 0.25.0 (fails posture check >= 0.26.0)
+ // Resource policies require posture-check-ver, so the failing peer
+ // should not see the router peer for those resources.
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5)
+
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ nm1 := componentsNetworkMap(account, "peer-1", validatedPeers)
+ require.NotNil(t, nm0)
+ require.NotNil(t, nm1)
+
+ // The passing peer should have more peers visible (including resource router peers)
+ // than the failing peer, because the failing peer is excluded from resource policies.
+ assert.Greater(t, len(nm0.Peers), len(nm1.Peers),
+ "passing peer (0.40.0) should see more peers than failing peer (0.25.0) due to posture-gated resource policies")
+}
+
+func TestComponents_MultiplePostureChecks(t *testing.T) {
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(50, 2)
+
+ // Keep only the posture-gated policy — remove per-group policies so connectivity is isolated
+ account.Policies = []*types.Policy{}
+
+ // Set kernel version on peers so the OS posture check can evaluate
+ for _, p := range account.Peers {
+ p.Meta.KernelVersion = "5.15.0"
+ }
+
+ account.PostureChecks = append(account.PostureChecks, &posture.Checks{
+ ID: "posture-check-os", Name: "Check OS",
+ Checks: posture.ChecksDefinition{
+ OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "0.0.1"}},
+ },
+ })
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-multi-posture", Name: "Multi Posture", Enabled: true, AccountID: "test-account",
+ SourcePostureChecks: []string{"posture-check-ver", "posture-check-os"},
+ Rules: []*types.PolicyRule{{
+ ID: "rule-multi-posture", Name: "Multi Check Rule", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
+ Bidirectional: true,
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ }},
+ })
+
+ // peer-0 (0.40.0, kernel 5.15.0) passes both checks, should see group-1 peers
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm0)
+ assert.NotEmpty(t, nm0.Peers, "peer passing both posture checks should see destination peers")
+
+ // peer-1 (0.25.0, kernel 5.15.0) fails version check, should NOT see group-1 peers
+ nm1 := componentsNetworkMap(account, "peer-1", validatedPeers)
+ require.NotNil(t, nm1)
+ assert.Empty(t, nm1.Peers,
+ "peer failing posture check should see no peers when posture-gated policy is the only connectivity")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 8. DNS
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_DNSConfigEnabled(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.True(t, nm.DNSConfig.ServiceEnable, "DNS should be enabled")
+ assert.NotEmpty(t, nm.DNSConfig.NameServerGroups, "should have nameserver groups")
+}
+
+func TestComponents_DNSDisabledByManagementGroup(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(100, 5)
+ account.DNSSettings.DisabledManagementGroups = []string{"group-all"}
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.False(t, nm.DNSConfig.ServiceEnable, "DNS should be disabled for peer in disabled group")
+}
+
+func TestComponents_DNSNameServerGroupDistribution(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ account.NameServerGroups["ns-group-0"] = &nbdns.NameServerGroup{
+ ID: "ns-group-0", Name: "Group 0 NS", Enabled: true, Groups: []string{"group-0"},
+ NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}},
+ }
+
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm0)
+ hasGroup0NS := false
+ for _, ns := range nm0.DNSConfig.NameServerGroups {
+ if ns.ID == "ns-group-0" {
+ hasGroup0NS = true
+ }
+ }
+ assert.True(t, hasGroup0NS, "peer-0 in group-0 should receive ns-group-0")
+
+ nm10 := componentsNetworkMap(account, "peer-10", validatedPeers)
+ require.NotNil(t, nm10)
+ hasGroup0NSForPeer10 := false
+ for _, ns := range nm10.DNSConfig.NameServerGroups {
+ if ns.ID == "ns-group-0" {
+ hasGroup0NSForPeer10 = true
+ }
+ }
+ assert.False(t, hasGroup0NSForPeer10, "peer-10 in group-1 should NOT receive ns-group-0")
+}
+
+func TestComponents_DNSCustomZone(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ customZone := nbdns.CustomZone{
+ Domain: "netbird.cloud.",
+ Records: []nbdns.SimpleRecord{
+ {Name: "peer0.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-0"].IP.String()},
+ {Name: "peer1.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-1"].IP.String()},
+ },
+ }
+
+ nm := account.GetPeerNetworkMapFromComponents(
+ context.Background(), "peer-0", customZone, nil,
+ validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(),
+ nil, account.GetActiveGroupUsers(),
+ )
+ require.NotNil(t, nm)
+ assert.True(t, nm.DNSConfig.ServiceEnable)
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 9. SSH
+// ──────────────────────────────────────────────────────────────────────────────
+
+func TestComponents_SSHPolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-ssh", Name: "Allow SSH", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
+ Bidirectional: false,
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ AuthorizedGroups: map[string][]string{"ssh-users": {"root"}},
+ }},
+ })
+
+ nm := componentsNetworkMap(account, "peer-10", validatedPeers)
+ require.NotNil(t, nm)
+ assert.True(t, nm.EnableSSH, "SSH should be enabled for destination peer of SSH policy")
+}
+
+func TestComponents_SSHNotEnabledWithoutPolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+ assert.False(t, nm.EnableSSH, "SSH should not be enabled without SSH policy")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 10. CROSS-PEER CONSISTENCY
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_AllPeersGetValidMaps verifies that every validated peer gets a
+// non-nil map with a consistent network serial and non-empty peer list.
+func TestComponents_AllPeersGetValidMaps(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(50, 5)
+ for peerID := range account.Peers {
+ if _, validated := validatedPeers[peerID]; !validated {
+ continue
+ }
+ nm := componentsNetworkMap(account, peerID, validatedPeers)
+ require.NotNil(t, nm, "network map should not be nil for %s", peerID)
+ assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID)
+ assert.NotEmpty(t, nm.Peers, "validated peer %s should see other peers", peerID)
+ }
+}
+
+// TestComponents_LargeScaleMapGeneration verifies that components can generate maps
+// at larger scales without errors and with consistent output.
+func TestComponents_LargeScaleMapGeneration(t *testing.T) {
+ scales := []struct{ peers, groups int }{
+ {500, 20},
+ {1000, 50},
+ }
+ for _, s := range scales {
+ t.Run(fmt.Sprintf("%dpeers_%dgroups", s.peers, s.groups), func(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(s.peers, s.groups)
+ testPeers := []string{"peer-0", fmt.Sprintf("peer-%d", s.peers/4), fmt.Sprintf("peer-%d", s.peers/2)}
+ for _, peerID := range testPeers {
+ nm := componentsNetworkMap(account, peerID, validatedPeers)
+ require.NotNil(t, nm, "network map should not be nil for %s", peerID)
+ assert.NotEmpty(t, nm.Peers, "peer %s should see other peers at scale", peerID)
+ assert.NotEmpty(t, nm.Routes, "peer %s should have routes at scale", peerID)
+ assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID)
+ }
+ })
+ }
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 11. PEER-AS-RESOURCE POLICIES
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_PeerAsSourceResource verifies that a policy with SourceResource.Type=Peer
+// targets only that specific peer as the source.
+func TestComponents_PeerAsSourceResource(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-peer-src", Name: "Peer Source Resource", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-peer-src", Name: "Peer Source Rule", Enabled: true,
+ Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true,
+ Ports: []string{"443"},
+ SourceResource: types.Resource{ID: "peer-0", Type: types.ResourceTypePeer},
+ Destinations: []string{"group-1"},
+ }},
+ })
+
+ // peer-0 is the source resource, should see group-1 peers
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm0)
+
+ has443 := false
+ for _, rule := range nm0.FirewallRules {
+ if rule.Port == "443" {
+ has443 = true
+ break
+ }
+ }
+ assert.True(t, has443, "peer-0 as source resource should have port 443 rule")
+}
+
+// TestComponents_PeerAsDestinationResource verifies that a policy with DestinationResource.Type=Peer
+// targets only that specific peer as the destination.
+func TestComponents_PeerAsDestinationResource(t *testing.T) {
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-peer-dst", Name: "Peer Dest Resource", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-peer-dst", Name: "Peer Dest Rule", Enabled: true,
+ Action: types.PolicyTrafficActionAccept,
+ Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true,
+ Ports: []string{"443"},
+ Sources: []string{"group-0"},
+ DestinationResource: types.Resource{ID: "peer-15", Type: types.ResourceTypePeer},
+ }},
+ })
+
+ // peer-0 is in group-0 (source), should see peer-15 as destination
+ nm0 := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm0)
+
+ peerIDs := make(map[string]bool, len(nm0.Peers))
+ for _, p := range nm0.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-15"], "peer-0 should see peer-15 via peer-as-destination-resource policy")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 12. MULTIPLE RULES PER POLICY
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_MultipleRulesPerPolicy verifies a policy with multiple rules generates
+// firewall rules for each.
+func TestComponents_MultipleRulesPerPolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-multi-rule", Name: "Multi Rule Policy", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{
+ {
+ ID: "rule-http", Name: "Allow HTTP", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true, Ports: []string{"80"},
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ },
+ {
+ ID: "rule-https", Name: "Allow HTTPS", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true, Ports: []string{"443"},
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ },
+ },
+ })
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ has80, has443 := false, false
+ for _, rule := range nm.FirewallRules {
+ if rule.Port == "80" {
+ has80 = true
+ }
+ if rule.Port == "443" {
+ has443 = true
+ }
+ }
+ assert.True(t, has80, "should have firewall rule for port 80 from first rule")
+ assert.True(t, has443, "should have firewall rule for port 443 from second rule")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 13. SSH AUTHORIZED USERS CONTENT
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_SSHAuthorizedUsersContent verifies that SSH policies populate
+// the AuthorizedUsers map with the correct users and machine mappings.
+func TestComponents_SSHAuthorizedUsersContent(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}}
+ account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-ssh", Name: "Allow SSH", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH,
+ Bidirectional: false,
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ AuthorizedGroups: map[string][]string{"ssh-users": {"root", "admin"}},
+ }},
+ })
+
+ // peer-10 is in group-1 (destination)
+ nm := componentsNetworkMap(account, "peer-10", validatedPeers)
+ require.NotNil(t, nm)
+ assert.True(t, nm.EnableSSH, "SSH should be enabled")
+ assert.NotNil(t, nm.AuthorizedUsers, "AuthorizedUsers should not be nil")
+ assert.NotEmpty(t, nm.AuthorizedUsers, "AuthorizedUsers should have entries")
+
+ // Check that "root" machine user mapping exists
+ _, hasRoot := nm.AuthorizedUsers["root"]
+ _, hasAdmin := nm.AuthorizedUsers["admin"]
+ assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping")
+}
+
+// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with
+// SSHEnabled peer implies legacy SSH access.
+func TestComponents_SSHLegacyImpliedSSH(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ // Enable SSH on the destination peer
+ account.Peers["peer-10"].SSHEnabled = true
+
+ // The default "Allow All" policy with Protocol=ALL + SSHEnabled peer should imply SSH
+ nm := componentsNetworkMap(account, "peer-10", validatedPeers)
+ require.NotNil(t, nm)
+ assert.True(t, nm.EnableSSH, "SSH should be implied by ALL protocol policy with SSHEnabled peer")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 14. ROUTE DEFAULT PERMIT (no AccessControlGroups)
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_RouteDefaultPermit verifies that a route without AccessControlGroups
+// generates default permit firewall rules (0.0.0.0/0 source).
+func TestComponents_RouteDefaultPermit(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ // Add a route without ACGs — this peer is the routing peer
+ routingPeerID := "peer-5"
+ account.Routes["route-no-acg"] = &route.Route{
+ ID: "route-no-acg", Network: netip.MustParsePrefix("192.168.99.0/24"),
+ PeerID: routingPeerID, Peer: account.Peers[routingPeerID].Key,
+ Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
+ AccessControlGroups: []string{},
+ AccountID: "test-account",
+ }
+
+ // The routing peer should get default permit route firewall rules
+ nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
+ require.NotNil(t, nm)
+
+ hasDefaultPermit := false
+ for _, rfr := range nm.RoutesFirewallRules {
+ for _, src := range rfr.SourceRanges {
+ if src == "0.0.0.0/0" || src == "::/0" {
+ hasDefaultPermit = true
+ break
+ }
+ }
+ }
+ assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 15. MULTIPLE ROUTERS PER NETWORK
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_MultipleRoutersPerNetwork verifies that a network resource
+// with multiple routers provides routes through all available routers.
+func TestComponents_MultipleRoutersPerNetwork(t *testing.T) {
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
+
+ netID := "net-multi-router"
+ resID := "res-multi-router"
+ account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Multi Router Network", AccountID: "test-account"})
+ account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
+ ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true,
+ Address: "multi-svc.netbird.cloud",
+ })
+ account.NetworkRouters = append(account.NetworkRouters,
+ &routerTypes.NetworkRouter{ID: "router-a", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", Metric: 100},
+ &routerTypes.NetworkRouter{ID: "router-b", NetworkID: netID, Peer: "peer-15", Enabled: true, AccountID: "test-account", Metric: 200},
+ )
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-multi-router-res", Name: "Multi Router Resource", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-multi-router-res", Name: "Allow Multi Router", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID},
+ }},
+ })
+
+ // peer-0 is in group-0 (source), should see both router peers
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-5"], "source peer should see router-a (peer-5)")
+ assert.True(t, peerIDs["peer-15"], "source peer should see router-b (peer-15)")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 16. PEER-AS-NAMESERVER EXCLUSION
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_PeerIsNameserverExcludedFromNSGroup verifies that a peer serving
+// as a nameserver does not receive its own NS group in DNS config.
+func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ // peer-0 has IP 100.64.0.0 — make it a nameserver
+ nsIP := account.Peers["peer-0"].IP
+ account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{
+ ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"},
+ NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}},
+ }
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ hasSelfNS := false
+ for _, ns := range nm.DNSConfig.NameServerGroups {
+ if ns.ID == "ns-self" {
+ hasSelfNS = true
+ }
+ }
+ assert.False(t, hasSelfNS, "peer serving as nameserver should NOT receive its own NS group")
+
+ // peer-10 is NOT the nameserver, should receive the NS group
+ nm10 := componentsNetworkMap(account, "peer-10", validatedPeers)
+ require.NotNil(t, nm10)
+ hasNSForPeer10 := false
+ for _, ns := range nm10.DNSConfig.NameServerGroups {
+ if ns.ID == "ns-self" {
+ hasNSForPeer10 = true
+ }
+ }
+ assert.True(t, hasNSForPeer10, "non-nameserver peer should receive the NS group")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 17. DOMAIN NETWORK RESOURCES
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_DomainNetworkResource verifies that domain-based network resources
+// produce routes with the correct domain configuration.
+func TestComponents_DomainNetworkResource(t *testing.T) {
+ account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)
+
+ netID := "net-domain"
+ resID := "res-domain"
+ account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Domain Network", AccountID: "test-account"})
+ account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
+ ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true,
+ Address: "api.example.com", Type: "domain",
+ })
+ account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
+ ID: "router-domain", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account",
+ })
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-domain-res", Name: "Domain Resource Policy", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{{
+ ID: "rule-domain-res", Name: "Allow Domain", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
+ Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID},
+ }},
+ })
+
+ // peer-0 is source, should get route to the domain resource via peer-5
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ peerIDs := make(map[string]bool, len(nm.Peers))
+ for _, p := range nm.Peers {
+ peerIDs[p.ID] = true
+ }
+ assert.True(t, peerIDs["peer-5"], "source peer should see domain resource router peer")
+}
+
+// ──────────────────────────────────────────────────────────────────────────────
+// 18. DISABLED RULE WITHIN ENABLED POLICY
+// ──────────────────────────────────────────────────────────────────────────────
+
+// TestComponents_DisabledRuleInEnabledPolicy verifies that a disabled rule within
+// an enabled policy does not generate firewall rules.
+func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) {
+ account, validatedPeers := scalableTestAccount(20, 2)
+
+ account.Policies = append(account.Policies, &types.Policy{
+ ID: "policy-mixed-rules", Name: "Mixed Rules", Enabled: true, AccountID: "test-account",
+ Rules: []*types.PolicyRule{
+ {
+ ID: "rule-enabled", Name: "Enabled Rule", Enabled: true,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true, Ports: []string{"3000"},
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ },
+ {
+ ID: "rule-disabled", Name: "Disabled Rule", Enabled: false,
+ Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
+ Bidirectional: true, Ports: []string{"3001"},
+ Sources: []string{"group-0"}, Destinations: []string{"group-1"},
+ },
+ },
+ })
+
+ nm := componentsNetworkMap(account, "peer-0", validatedPeers)
+ require.NotNil(t, nm)
+
+ has3000, has3001 := false, false
+ for _, rule := range nm.FirewallRules {
+ if rule.Port == "3000" {
+ has3000 = true
+ }
+ if rule.Port == "3001" {
+ has3001 = true
+ }
+ }
+ assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000")
+ assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001")
+}
diff --git a/relay/server/handshake.go b/relay/server/handshake.go
index 8c3ee1899..067888406 100644
--- a/relay/server/handshake.go
+++ b/relay/server/handshake.go
@@ -1,11 +1,13 @@
package server
import (
+ "context"
"fmt"
- "net"
+ "time"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address"
@@ -13,6 +15,12 @@ import (
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
)
+const (
+ // handshakeTimeout bounds how long a connection may remain in the
+ // pre-authentication handshake phase before being closed.
+ handshakeTimeout = 10 * time.Second
+)
+
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
}
type handshake struct {
- conn net.Conn
+ conn listener.Conn
validator Validator
preparedMsg *preparedMsg
@@ -66,9 +74,9 @@ type handshake struct {
peerID *messages.PeerID
}
-func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
+func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize)
- n, err := h.conn.Read(buf)
+ n, err := h.conn.Read(ctx, buf)
if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
}
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return peerID, nil
}
-func (h *handshake) handshakeResponse() error {
+func (h *handshake) handshakeResponse(ctx context.Context) error {
var responseMsg []byte
if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
responseMsg = h.preparedMsg.responseHelloMsg
}
- if _, err := h.conn.Write(responseMsg); err != nil {
+ if _, err := h.conn.Write(ctx, responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
}
diff --git a/relay/server/listener/conn.go b/relay/server/listener/conn.go
new file mode 100644
index 000000000..ef0869594
--- /dev/null
+++ b/relay/server/listener/conn.go
@@ -0,0 +1,14 @@
+package listener
+
+import (
+ "context"
+ "net"
+)
+
+// Conn is the relay connection contract implemented by WS and QUIC transports.
+type Conn interface {
+ Read(ctx context.Context, b []byte) (n int, err error)
+ Write(ctx context.Context, b []byte) (n int, err error)
+ RemoteAddr() net.Addr
+ Close() error
+}
diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go
deleted file mode 100644
index 0a79182f4..000000000
--- a/relay/server/listener/listener.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package listener
-
-import (
- "context"
- "net"
-
- "github.com/netbirdio/netbird/relay/protocol"
-)
-
-type Listener interface {
- Listen(func(conn net.Conn)) error
- Shutdown(ctx context.Context) error
- Protocol() protocol.Protocol
-}
diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go
index 6e2201bf7..d8dafcd1f 100644
--- a/relay/server/listener/quic/conn.go
+++ b/relay/server/listener/quic/conn.go
@@ -3,33 +3,26 @@ package quic
import (
"context"
"errors"
- "fmt"
"net"
"sync"
- "time"
"github.com/quic-go/quic-go"
)
type Conn struct {
- session *quic.Conn
- closed bool
- closedMu sync.Mutex
- ctx context.Context
- ctxCancel context.CancelFunc
+ session *quic.Conn
+ closed bool
+ closedMu sync.Mutex
}
func NewConn(session *quic.Conn) *Conn {
- ctx, cancel := context.WithCancel(context.Background())
return &Conn{
- session: session,
- ctx: ctx,
- ctxCancel: cancel,
+ session: session,
}
}
-func (c *Conn) Read(b []byte) (n int, err error) {
- dgram, err := c.session.ReceiveDatagram(c.ctx)
+func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
+ dgram, err := c.session.ReceiveDatagram(ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
@@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return n, nil
}
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err)
}
return len(b), nil
}
-func (c *Conn) LocalAddr() net.Addr {
- return c.session.LocalAddr()
-}
-
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
-func (c *Conn) SetReadDeadline(t time.Time) error {
- return nil
-}
-
-func (c *Conn) SetWriteDeadline(t time.Time) error {
- return fmt.Errorf("SetWriteDeadline is not implemented")
-}
-
-func (c *Conn) SetDeadline(t time.Time) error {
- return fmt.Errorf("SetDeadline is not implemented")
-}
-
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {
@@ -74,8 +51,6 @@ func (c *Conn) Close() error {
c.closed = true
c.closedMu.Unlock()
- c.ctxCancel() // Cancel the context
-
sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr
}
diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go
index 797223e74..68f0e03c0 100644
--- a/relay/server/listener/quic/listener.go
+++ b/relay/server/listener/quic/listener.go
@@ -5,12 +5,12 @@ import (
"crypto/tls"
"errors"
"fmt"
- "net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
+ relaylistener "github.com/netbirdio/netbird/relay/server/listener"
nbRelay "github.com/netbirdio/netbird/shared/relay"
)
@@ -25,7 +25,7 @@ type Listener struct {
listener *quic.Listener
}
-func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
+func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go
index d5bce56f7..c22b5719d 100644
--- a/relay/server/listener/ws/conn.go
+++ b/relay/server/listener/ws/conn.go
@@ -18,25 +18,21 @@ const (
type Conn struct {
*websocket.Conn
- lAddr *net.TCPAddr
rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
- ctx context.Context
}
-func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
+func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
- lAddr: lAddr,
rAddr: rAddr,
- ctx: context.Background(),
}
}
-func (c *Conn) Read(b []byte) (n int, err error) {
- t, r, err := c.Reader(c.ctx)
+func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
+ t, r, err := c.Reader(ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
@@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// Write writes a binary message with the given payload.
// It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout.
-func (c *Conn) Write(b []byte) (int, error) {
- ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
+func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
+ ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err
}
-func (c *Conn) LocalAddr() net.Addr {
- return c.lAddr
-}
-
func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
-func (c *Conn) SetReadDeadline(t time.Time) error {
- return fmt.Errorf("SetReadDeadline is not implemented")
-}
-
-func (c *Conn) SetWriteDeadline(t time.Time) error {
- return fmt.Errorf("SetWriteDeadline is not implemented")
-}
-
-func (c *Conn) SetDeadline(t time.Time) error {
- return fmt.Errorf("SetDeadline is not implemented")
-}
-
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true
diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go
index 12219e29b..ba175f901 100644
--- a/relay/server/listener/ws/listener.go
+++ b/relay/server/listener/ws/listener.go
@@ -7,11 +7,13 @@ import (
"fmt"
"net"
"net/http"
+ "time"
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
+ relaylistener "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay"
)
@@ -27,18 +29,19 @@ type Listener struct {
TLSConfig *tls.Config
server *http.Server
- acceptFn func(conn net.Conn)
+ acceptFn func(conn relaylistener.Conn)
}
-func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
+func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept)
l.server = &http.Server{
- Addr: l.Address,
- Handler: mux,
- TLSConfig: l.TLSConfig,
+ Addr: l.Address,
+ Handler: mux,
+ TLSConfig: l.TLSConfig,
+ ReadHeaderTimeout: 5 * time.Second,
}
log.Infof("WS server listening address: %s", l.Address)
@@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}
- lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
- if err != nil {
- err = wsConn.Close(websocket.StatusInternalError, "internal error")
- if err != nil {
- log.Errorf("failed to close ws connection: %s", err)
- }
- return
- }
-
log.Infof("WS client connected from: %s", rAddr)
- conn := NewConn(wsConn, lAddr, rAddr)
+ conn := NewConn(wsConn, rAddr)
l.acceptFn(conn)
}
diff --git a/relay/server/peer.go b/relay/server/peer.go
index c5ff41857..8376cdfa7 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/metrics"
+ "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/store"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
@@ -26,11 +27,14 @@ type Peer struct {
metrics *metrics.Metrics
log *log.Entry
id messages.PeerID
- conn net.Conn
+ conn listener.Conn
connMu sync.RWMutex
store *store.Store
notifier *store.PeerNotifier
+ ctx context.Context
+ ctxCancel context.CancelFunc
+
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
@@ -38,14 +42,17 @@ type Peer struct {
}
// NewPeer creates a new Peer instance and prepare custom logging
-func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
+func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
+ ctx, cancel := context.WithCancel(context.Background())
p := &Peer{
- metrics: metrics,
- log: log.WithField("peer_id", id.String()),
- id: id,
- conn: conn,
- store: store,
- notifier: notifier,
+ metrics: metrics,
+ log: log.WithField("peer_id", id.String()),
+ id: id,
+ conn: conn,
+ store: store,
+ notifier: notifier,
+ ctx: ctx,
+ ctxCancel: cancel,
}
return p
@@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store
func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() {
+ p.ctxCancel()
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
@@ -64,8 +72,7 @@ func (p *Peer) Work() {
}
}()
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ ctx := p.ctx
hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx)
@@ -73,7 +80,7 @@ func (p *Peer) Work() {
buf := make([]byte, bufferSize)
for {
- n, err := p.conn.Read(buf)
+ n, err := p.conn.Read(ctx, buf)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err)
@@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
}
// Write writes data to the connection
-func (p *Peer) Write(b []byte) (int, error) {
+func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
p.connMu.RLock()
defer p.connMu.RUnlock()
- return p.conn.Write(b)
+ return p.conn.Write(ctx, b)
}
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
@@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String())
}
+ p.ctxCancel()
if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err)
}
@@ -156,6 +164,7 @@ func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock()
+ p.ctxCancel()
if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err)
}
@@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
- writeDone := make(chan struct{})
- var err error
- go func() {
- _, err = p.conn.Write(buf)
- close(writeDone)
- }()
-
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-writeDone:
- return err
- }
+ _, err := p.conn.Write(ctx, buf)
+ return err
}
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
for {
select {
case <-hc.HealthCheck:
- _, err := p.Write(messages.MarshalHealthcheck())
+ _, err := p.Write(ctx, messages.MarshalHealthcheck())
if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err)
return
@@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return
}
- n, err := dp.Write(msg)
+ n, err := dp.Write(dp.ctx, msg)
if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String())
return
}
- p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
+ p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
}
func (p *Peer) handleSubscribePeerState(msg []byte) {
@@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
}
for n, msg := range msgs {
- if _, err := p.Write(msg); err != nil {
+ if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
@@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
}
for n, msg := range msgs {
- if _, err := p.Write(msg); err != nil {
+ if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
diff --git a/relay/server/relay.go b/relay/server/relay.go
index bb355f58f..56add8bea 100644
--- a/relay/server/relay.go
+++ b/relay/server/relay.go
@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
- "net"
"net/url"
"sync"
"time"
@@ -13,11 +12,20 @@ import (
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid"
+ "github.com/netbirdio/netbird/relay/protocol"
+ "github.com/netbirdio/netbird/relay/server/listener"
+
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
)
+type Listener interface {
+ Listen(func(conn listener.Conn)) error
+ Shutdown(ctx context.Context) error
+ Protocol() protocol.Protocol
+}
+
type Config struct {
Meter metric.Meter
ExposedAddress string
@@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) {
}
// Accept start to handle a new peer connection
-func (r *Relay) Accept(conn net.Conn) {
+func (r *Relay) Accept(conn listener.Conn) {
acceptTime := time.Now()
r.closeMu.RLock()
defer r.closeMu.RUnlock()
@@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) {
return
}
+ hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
+ defer hsCancel()
+
h := handshake{
conn: conn,
validator: r.validator,
preparedMsg: r.preparedMsg,
}
- peerID, err := h.handshakeReceive()
+ peerID, err := h.handshakeReceive(hsCtx)
if err != nil {
if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr())
@@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) {
r.metrics.PeerDisconnected(peer.String())
}()
- if err := h.handshakeResponse(); err != nil {
+ if err := h.handshakeResponse(hsCtx); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close()
}
diff --git a/relay/server/server.go b/relay/server/server.go
index a0f7eb73c..340da55b8 100644
--- a/relay/server/server.go
+++ b/relay/server/server.go
@@ -3,7 +3,6 @@ package server
import (
"context"
"crypto/tls"
- "net"
"net/url"
"sync"
@@ -31,7 +30,7 @@ type ListenerConfig struct {
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
- listeners []listener.Listener
+ listeners []Listener
listenerMux sync.Mutex
}
@@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) {
}
return &Server{
relay: relay,
- listeners: make([]listener.Listener, 0, 2),
+ listeners: make([]Listener, 0, 2),
}, nil
}
@@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
wg := sync.WaitGroup{}
for _, l := range r.listeners {
wg.Add(1)
- go func(listener listener.Listener) {
+ go func(listener Listener) {
defer wg.Done()
errChan <- listener.Listen(r.relay.Accept)
}(l)
@@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL {
// RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners.
-func (r *Server) RelayAccept() func(conn net.Conn) {
+func (r *Server) RelayAccept() func(conn listener.Conn) {
return r.relay.Accept
}