diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 0e2717756..e3d3afe5f 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error { cmd.Println("Log level set to trace.") } + needsRestoreUp := false if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) } else { + needsRestoreUp = !stateWasDown cmd.Println("netbird down") } @@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) } else { + needsRestoreUp = false cmd.Println("netbird up") } @@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + if needsRestoreUp { + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { + cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message()) + } else { + cmd.Println("netbird up (restored)") + } + } + if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) diff --git a/client/internal/connect.go b/client/internal/connect.go index 7ee745ae6..97c350d4e 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -114,6 +114,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. @@ -123,6 +124,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 330cebe49..00f8b1a8d 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -25,6 +25,7 @@ import ( "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/updater/installer" @@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states for the active profile. +service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists. metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. @@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add corrupted state files to debug bundle: %v", err) } + if err := g.addServiceParams(); err != nil { + log.Errorf("failed to add service params to debug bundle: %v", err) + } + if err := g.addMetrics(); err != nil { log.Errorf("failed to add metrics to debug bundle: %v", err) } @@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error { return nil } +const ( + serviceParamsFile = "service.json" + serviceParamsBundle = "service_params.json" + maskedValue = "***" + envVarPrefix = "NB_" + jsonKeyManagementURL = "management_url" + jsonKeyServiceEnv = "service_env_vars" +) + +var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"} + +// addServiceParams reads the service.json file and adds a sanitized version to the bundle. +// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized. +func (g *BundleGenerator) addServiceParams() error { + path := filepath.Join(configs.StateDir, serviceParamsFile) + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read service params: %w", err) + } + + var params map[string]any + if err := json.Unmarshal(data, ¶ms); err != nil { + return fmt.Errorf("parse service params: %w", err) + } + + if g.anonymize { + if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" { + params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL) + } + } + + g.sanitizeServiceEnvVars(params) + + sanitizedData, err := json.MarshalIndent(params, "", " ") + if err != nil { + return fmt.Errorf("marshal sanitized service params: %w", err) + } + + if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil { + return fmt.Errorf("add service params to zip: %w", err) + } + + return nil +} + +// sanitizeServiceEnvVars masks or anonymizes env var values in service params. +// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked. +// Other NB_ var values are passed through the anonymizer when anonymization is enabled. +func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) { + envVars, ok := params[jsonKeyServiceEnv].(map[string]any) + if !ok { + return + } + + sanitized := make(map[string]any, len(envVars)) + for k, v := range envVars { + val, _ := v.(string) + switch { + case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k): + sanitized[k] = maskedValue + case g.anonymize: + sanitized[k] = g.anonymizer.AnonymizeString(val) + default: + sanitized[k] = val + } + } + params[jsonKeyServiceEnv] = sanitized +} + +// isSensitiveEnvVar returns true for env var names that may contain secrets. +func isSensitiveEnvVar(key string) bool { + lower := strings.ToLower(key) + for _, s := range sensitiveEnvSubstrings { + if strings.Contains(lower, s) { + return true + } + } + return false +} + func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { configContent.WriteString("NetBird Client Configuration:\n\n") diff --git a/client/internal/debug/debug_test.go b/client/internal/debug/debug_test.go index e242b8b1b..49c18c679 100644 --- a/client/internal/debug/debug_test.go +++ b/client/internal/debug/debug_test.go @@ -1,8 +1,12 @@ package debug import ( + "archive/zip" + "bytes" "encoding/json" "net" + "os" + "path/filepath" "strings" "testing" @@ -10,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/configs" mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) @@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) { } } +func TestIsSensitiveEnvVar(t *testing.T) { + tests := []struct { + key string + sensitive bool + }{ + {"NB_SETUP_KEY", true}, + {"NB_API_TOKEN", true}, + {"NB_CLIENT_SECRET", true}, + {"NB_PASSWORD", true}, + {"NB_CREDENTIAL", true}, + {"NB_LOG_LEVEL", false}, + {"NB_MANAGEMENT_URL", false}, + {"NB_HOSTNAME", false}, + {"HOME", false}, + {"PATH", false}, + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key)) + }) + } +} + +func TestSanitizeServiceEnvVars(t *testing.T) { + tests := []struct { + name string + anonymize bool + input map[string]any + check func(t *testing.T, params map[string]any) + }{ + { + name: "no env vars key", + anonymize: false, + input: map[string]any{"management_url": "https://mgmt.example.com"}, + check: func(t *testing.T, params map[string]any) { + t.Helper() + assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched") + _, ok := params[jsonKeyServiceEnv] + assert.False(t, ok, "service_env_vars should not be added") + }, + }, + { + name: "non-NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "HOME": "/root", + "PATH": "/usr/bin", + "NB_LOG_LEVEL": "debug", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked") + assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked") + assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "sensitive NB vars are masked", + anonymize: false, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_SETUP_KEY": "abc123", + "NB_API_TOKEN": "tok_xyz", + "NB_LOG_LEVEL": "info", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked") + assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked") + assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through") + }, + }, + { + name: "safe NB vars anonymized when anonymize is true", + anonymize: true, + input: map[string]any{ + jsonKeyServiceEnv: map[string]any{ + "NB_MANAGEMENT_URL": "https://mgmt.example.com:443", + "NB_LOG_LEVEL": "debug", + "NB_SETUP_KEY": "secret", + "SOME_OTHER": "val", + }, + }, + check: func(t *testing.T, params map[string]any) { + t.Helper() + env := params[jsonKeyServiceEnv].(map[string]any) + // Safe NB_ values should be anonymized (not the original, not masked) + mgmtVal := env["NB_MANAGEMENT_URL"].(string) + assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized") + assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked") + + logVal := env["NB_LOG_LEVEL"].(string) + assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked") + + // Sensitive and non-NB_ still masked + assert.Equal(t, maskedValue, env["NB_SETUP_KEY"]) + assert.Equal(t, maskedValue, env["SOME_OTHER"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + g := &BundleGenerator{ + anonymize: tt.anonymize, + anonymizer: anonymizer, + } + g.sanitizeServiceEnvVars(tt.input) + tt.check(t, tt.input) + }) + } +} + +func TestAddServiceParams(t *testing.T) { + t.Run("missing service.json returns nil", func(t *testing.T) { + g := &BundleGenerator{ + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + } + + origStateDir := configs.StateDir + configs.StateDir = t.TempDir() + t.Cleanup(func() { configs.StateDir = origStateDir }) + + err := g.addServiceParams() + assert.NoError(t, err) + }) + + t.Run("management_url anonymized when anonymize is true", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + jsonKeyServiceEnv: map[string]any{ + "NB_LOG_LEVEL": "trace", + }, + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: true, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + require.Len(t, zr.File, 1) + assert.Equal(t, serviceParamsBundle, zr.File[0].Name) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + mgmt := result[jsonKeyManagementURL].(string) + assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized") + assert.NotEmpty(t, mgmt) + + env := result[jsonKeyServiceEnv].(map[string]any) + assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked") + }) + + t.Run("management_url preserved when anonymize is false", func(t *testing.T) { + dir := t.TempDir() + origStateDir := configs.StateDir + configs.StateDir = dir + t.Cleanup(func() { configs.StateDir = origStateDir }) + + input := map[string]any{ + jsonKeyManagementURL: "https://api.example.com:443", + } + data, err := json.Marshal(input) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600)) + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + g := &BundleGenerator{ + anonymize: false, + anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), + archive: zw, + } + + require.NoError(t, g.addServiceParams()) + require.NoError(t, zw.Close()) + + zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + require.NoError(t, err) + + rc, err := zr.File[0].Open() + require.NoError(t, err) + defer rc.Close() + + var result map[string]any + require.NoError(t, json.NewDecoder(rc).Decode(&result)) + + assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved") + }) +} + // Helper function to check if IP is in CGNAT range func isInCGNATRange(ip net.IP) bool { cgnat := net.IPNet{ diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 73f70035f..2c6b7dbc3 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }) } -// TestLocalResolver_Stop tests cleanup on Stop +// TestLocalResolver_Stop tests cleanup on GracefullyStop func TestLocalResolver_Stop(t *testing.T) { - t.Run("Stop clears all state", func(t *testing.T) { + t.Run("GracefullyStop clears all state", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) { assert.False(t, resolver.isInManagedZone("host.example.com.")) }) - t.Run("Stop is safe to call multiple times", func(t *testing.T) { + t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() resolver.Update([]nbdns.CustomZone{{ Domain: "example.com.", @@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) { resolver.Stop() }) - t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) { resolver := NewResolver() lookupStarted := make(chan struct{}) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4fda5db3..f7865047b 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -187,11 +187,16 @@ func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, + hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { + log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager + ds.hostsDNSHolder.set(hostsDnsList) + ds.permanent = true + ds.addHostRootZone() return ds } diff --git a/client/internal/engine.go b/client/internal/engine.go index 16410519b..ce4d71e35 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -47,6 +47,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" @@ -214,9 +215,10 @@ type Engine struct { // checks are the client-applied posture checks that need to be evaluated on the client checks []*mgmProto.Checks - relayManager *relayClient.Manager - stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + relayManager *relayClient.Manager + stateManager *statemanager.Manager + portForwardManager *portforward.Manager + srWatcher *guard.SRWatcher // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex @@ -263,26 +265,27 @@ func NewEngine( mobileDep MobileDependency, ) *Engine { engine := &Engine{ - clientCtx: clientCtx, - clientCancel: clientCancel, - signal: services.SignalClient, - signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), - mgmClient: services.MgmClient, - relayManager: services.RelayManager, - peerStore: peerstore.NewConnStore(), - syncMsgMux: &sync.Mutex{}, - config: config, - mobileDep: mobileDep, - STUNs: []*stun.URI{}, - TURNs: []*stun.URI{}, - networkSerial: 0, - statusRecorder: services.StatusRecorder, - stateManager: services.StateManager, - checks: services.Checks, - probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), - jobExecutor: jobexec.NewExecutor(), - clientMetrics: services.ClientMetrics, - updateManager: services.UpdateManager, + clientCtx: clientCtx, + clientCancel: clientCancel, + signal: services.SignalClient, + signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), + mgmClient: services.MgmClient, + relayManager: services.RelayManager, + peerStore: peerstore.NewConnStore(), + syncMsgMux: &sync.Mutex{}, + config: config, + mobileDep: mobileDep, + STUNs: []*stun.URI{}, + TURNs: []*stun.URI{}, + networkSerial: 0, + statusRecorder: services.StatusRecorder, + stateManager: services.StateManager, + portForwardManager: portforward.NewManager(), + checks: services.Checks, + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), + jobExecutor: jobexec.NewExecutor(), + clientMetrics: services.ClientMetrics, + updateManager: services.UpdateManager, } log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) @@ -541,6 +544,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // conntrack entries from being created before the rules are in place e.setupWGProxyNoTrack() + // Start after interface is up since port may have been resolved from 0 or changed if occupied + e.shutdownWg.Add(1) + go func() { + defer e.shutdownWg.Done() + e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort)) + }() + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -1627,12 +1637,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV } serviceDependencies := peer.ServiceDependencies{ - StatusRecorder: e.statusRecorder, - Signaler: e.signaler, - IFaceDiscover: e.mobileDep.IFaceDiscover, - RelayManager: e.relayManager, - SrWatcher: e.srWatcher, - MetricsRecorder: e.clientMetrics, + StatusRecorder: e.statusRecorder, + Signaler: e.signaler, + IFaceDiscover: e.mobileDep.IFaceDiscover, + RelayManager: e.relayManager, + SrWatcher: e.srWatcher, + PortForwardManager: e.portForwardManager, + MetricsRecorder: e.clientMetrics, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1789,6 +1800,12 @@ func (e *Engine) close() { if e.rpManager != nil { _ = e.rpManager.Close() } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := e.portForwardManager.GracefullyStop(ctx); err != nil { + log.Warnf("failed to gracefully stop port forwarding manager: %s", err) + } } func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { @@ -1894,7 +1911,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index bea0725f2..8d1585b3f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -22,6 +22,7 @@ import ( icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/worker" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" @@ -45,6 +46,7 @@ type ServiceDependencies struct { RelayManager *relayClient.Manager SrWatcher *guard.SRWatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher + PortForwardManager *portforward.Manager MetricsRecorder MetricsRecorder } @@ -87,16 +89,17 @@ type ConnConfig struct { } type Conn struct { - Log *log.Entry - mu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc - config ConnConfig - statusRecorder *Status - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - relayManager *relayClient.Manager - srWatcher *guard.SRWatcher + Log *log.Entry + mu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + config ConnConfig + statusRecorder *Status + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + relayManager *relayClient.Manager + srWatcher *guard.SRWatcher + portForwardManager *portforward.Manager onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string) @@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: dumpState, - endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), - wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), - metricsRecorder: services.MetricsRecorder, + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + portForwardManager: services.PortForwardManager, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: dumpState, + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), + metricsRecorder: services.MetricsRecorder, } return conn, nil diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index edd70fb20..29bf5aaaa 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/portforward" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) @@ -61,6 +62,9 @@ type WorkerICE struct { // we record the last known state of the ICE agent to avoid duplicate on disconnected events lastKnownState ice.ConnectionState + + // portForwardAttempted tracks if we've already tried port forwarding this session + portForwardAttempted bool } func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { @@ -214,6 +218,8 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { + w.portForwardAttempted = false + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() + + if candidate.Type() == ice.CandidateTypeServerReflexive { + w.injectPortForwardedCandidate(candidate) + } +} + +// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping. +func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) { + pfManager := w.conn.portForwardManager + if pfManager == nil { + return + } + + mapping := pfManager.GetMapping() + if mapping == nil { + return + } + + w.muxAgent.Lock() + if w.portForwardAttempted { + w.muxAgent.Unlock() + return + } + w.portForwardAttempted = true + w.muxAgent.Unlock() + + forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping) + if err != nil { + w.log.Warnf("create forwarded candidate: %v", err) + return + } + + w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)", + forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority()) + + go func() { + if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil { + w.log.Errorf("signal port-forwarded candidate: %v", err) + } + }() +} + +// createForwardedCandidate creates a new server reflexive candidate with the forwarded port. +// It uses the NAT gateway's external IP with the forwarded port. +func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) { + var externalIP string + if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() { + externalIP = mapping.ExternalIP.String() + } else { + // Fallback to STUN-discovered address if NAT didn't provide external IP + externalIP = srflxCandidate.Address() + } + + // Per RFC 8445, the related address for srflx is the base (host candidate address). + // If the original srflx has unspecified related address, use its own address as base. + relAddr := srflxCandidate.RelatedAddress().Address + if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" { + relAddr = srflxCandidate.Address() + } + + // Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates + // over regular srflx during ICE connectivity checks. + priority := srflxCandidate.Priority() + 1000 + + candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + Network: srflxCandidate.NetworkType().String(), + Address: externalIP, + Port: int(mapping.ExternalPort), + Component: srflxCandidate.Component(), + Priority: priority, + RelAddr: relAddr, + RelPort: int(mapping.InternalPort), + }) + if err != nil { + return nil, fmt.Errorf("create candidate: %w", err) + } + + for _, e := range srflxCandidate.Extensions() { + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = srflxCandidate.ID() + } + if err := candidate.AddExtension(e); err != nil { + return nil, fmt.Errorf("add extension: %w", err) + } + } + + return candidate, nil } func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { @@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { if !lok || !rok { continue } - w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms", sessionID, - local.NetworkType(), local.Type(), local.Address(), - remote.NetworkType(), remote.Type(), remote.Address(), + local.NetworkType(), local.Type(), local.Address(), local.Port(), + remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(), stat.CurrentRoundTripTime*1000) } } diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go new file mode 100644 index 000000000..444a6b478 --- /dev/null +++ b/client/internal/portforward/env.go @@ -0,0 +1,26 @@ +package portforward + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" +) + +func isDisabledByEnv() bool { + val := os.Getenv(envDisableNATMapper) + if val == "" { + return false + } + + disabled, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + return false + } + return disabled +} diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go new file mode 100644 index 000000000..bf7533af9 --- /dev/null +++ b/client/internal/portforward/manager.go @@ -0,0 +1,280 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "regexp" + "sync" + "time" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" +) + +const ( + defaultMappingTTL = 2 * time.Hour + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" +) + +// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML, +// allowing for whitespace/newlines between tags from different router firmware. +var upnpErrPermanentLeaseOnly = regexp.MustCompile(`\s*725\s*`) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// TODO: persist mapping state for crash recovery cleanup of permanent leases. +// Currently not done because State.Cleanup requires NAT gateway re-discovery, +// which blocks startup for ~10s when no gateway is present (affects all clients). + +type Manager struct { + cancel context.CancelFunc + + mapping *Mapping + mappingLock sync.Mutex + + wgPort uint16 + + done chan struct{} + stopCtx chan context.Context + + // protect exported functions + mu sync.Mutex +} + +// NewManager creates a new port forwarding manager. +func NewManager() *Manager { + return &Manager{ + stopCtx: make(chan context.Context, 1), + } +} + +func (m *Manager) Start(ctx context.Context, wgPort uint16) { + m.mu.Lock() + if m.cancel != nil { + m.mu.Unlock() + return + } + + if isDisabledByEnv() { + log.Infof("NAT port mapper disabled via %s", envDisableNATMapper) + m.mu.Unlock() + return + } + + if wgPort == 0 { + log.Warnf("invalid WireGuard port 0; NAT mapping disabled") + m.mu.Unlock() + return + } + m.wgPort = wgPort + + m.done = make(chan struct{}) + defer close(m.done) + + ctx, m.cancel = context.WithCancel(ctx) + m.mu.Unlock() + + gateway, mapping, err := m.setup(ctx) + if err != nil { + log.Infof("port forwarding setup: %v", err) + return + } + + m.mappingLock.Lock() + m.mapping = mapping + m.mappingLock.Unlock() + + m.renewLoop(ctx, gateway, mapping.TTL) + + select { + case cleanupCtx := <-m.stopCtx: + // block the Start while cleaned up gracefully + m.cleanup(cleanupCtx, gateway) + default: + // return Start immediately and cleanup in background + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second) + go func() { + defer cleanupCancel() + m.cleanup(cleanupCtx, gateway) + }() + } +} + +// GetMapping returns the current mapping if ready, nil otherwise +func (m *Manager) GetMapping() *Mapping { + m.mappingLock.Lock() + defer m.mappingLock.Unlock() + + if m.mapping == nil { + return nil + } + + mapping := *m.mapping + return &mapping +} + +// GracefullyStop cancels the manager and attempts to delete the port mapping. +// After GracefullyStop returns, the manager cannot be restarted. +func (m *Manager) GracefullyStop(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel == nil { + return nil + } + + // Send cleanup context before cancelling, so Start picks it up after renewLoop exits. + m.startTearDown(ctx) + + m.cancel() + m.cancel = nil + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } +} + +func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { + discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) + defer discoverCancel() + + gateway, err := nat.DiscoverGateway(discoverCtx) + if err != nil { + return nil, nil, fmt.Errorf("discover gateway: %w", err) + } + + log.Infof("discovered NAT gateway: %s", gateway.Type()) + + mapping, err := m.createMapping(ctx, gateway) + if err != nil { + return nil, nil, fmt.Errorf("create port mapping: %w", err) + } + return gateway, mapping, nil +} + +func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + ttl := defaultMappingTTL + externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + if !isPermanentLeaseRequired(err) { + return nil, err + } + log.Infof("gateway only supports permanent leases, retrying with indefinite duration") + ttl = 0 + externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl) + if err != nil { + return nil, err + } + } + + externalIP, err := gateway.GetExternalAddress() + if err != nil { + log.Debugf("failed to get external address: %v", err) + // todo return with err? + } + + mapping := &Mapping{ + Protocol: "udp", + InternalPort: m.wgPort, + ExternalPort: uint16(externalPort), + ExternalIP: externalIP, + NATType: gateway.Type(), + TTL: ttl, + } + + log.Infof("created port mapping: %d -> %d via %s (external IP: %s)", + m.wgPort, externalPort, gateway.Type(), externalIP) + return mapping, nil +} + +func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) { + if ttl == 0 { + // Permanent mappings don't expire, just wait for cancellation. + <-ctx.Done() + return + } + + ticker := time.NewTicker(ttl / 2) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.renewMapping(ctx, gateway); err != nil { + log.Warnf("failed to renew port mapping: %v", err) + continue + } + } + } +} + +func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL) + if err != nil { + return fmt.Errorf("add port mapping: %w", err) + } + + if uint16(externalPort) != m.mapping.ExternalPort { + log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort) + m.mappingLock.Lock() + m.mapping.ExternalPort = uint16(externalPort) + m.mappingLock.Unlock() + } + + log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort) + return nil +} + +func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) { + m.mappingLock.Lock() + mapping := m.mapping + m.mapping = nil + m.mappingLock.Unlock() + + if mapping == nil { + return + } + + if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil { + log.Warnf("delete port mapping on stop: %v", err) + return + } + + log.Infof("deleted port mapping for port %d", mapping.InternalPort) +} + +func (m *Manager) startTearDown(ctx context.Context) { + select { + case m.stopCtx <- ctx: + default: + } +} + +// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725). +func isPermanentLeaseRequired(err error) bool { + return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error()) +} diff --git a/client/internal/portforward/manager_js.go b/client/internal/portforward/manager_js.go new file mode 100644 index 000000000..36c55063b --- /dev/null +++ b/client/internal/portforward/manager_js.go @@ -0,0 +1,39 @@ +package portforward + +import ( + "context" + "net" + "time" +) + +// Mapping represents an active NAT port mapping. +type Mapping struct { + Protocol string + InternalPort uint16 + ExternalPort uint16 + ExternalIP net.IP + NATType string + // TTL is the lease duration. Zero means a permanent lease that never expires. + TTL time.Duration +} + +// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported. +type Manager struct{} + +// NewManager returns a stub manager for js/wasm builds. +func NewManager() *Manager { + return &Manager{} +} + +// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments. +func (m *Manager) Start(context.Context, uint16) { + // no NAT traversal in wasm +} + +// GracefullyStop is a no-op on js/wasm. +func (m *Manager) GracefullyStop(context.Context) error { return nil } + +// GetMapping always returns nil on js/wasm. +func (m *Manager) GetMapping() *Mapping { + return nil +} diff --git a/client/internal/portforward/manager_test.go b/client/internal/portforward/manager_test.go new file mode 100644 index 000000000..1f66f9ccd --- /dev/null +++ b/client/internal/portforward/manager_test.go @@ -0,0 +1,201 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockNAT struct { + natType string + deviceAddr net.IP + externalAddr net.IP + internalAddr net.IP + mappings map[int]int + addMappingErr error + deleteMappingErr error + onlyPermanentLeases bool + lastTimeout time.Duration +} + +func newMockNAT() *mockNAT { + return &mockNAT{ + natType: "Mock-NAT", + deviceAddr: net.ParseIP("192.168.1.1"), + externalAddr: net.ParseIP("203.0.113.50"), + internalAddr: net.ParseIP("192.168.1.100"), + mappings: make(map[int]int), + } +} + +func (m *mockNAT) Type() string { + return m.natType +} + +func (m *mockNAT) GetDeviceAddress() (net.IP, error) { + return m.deviceAddr, nil +} + +func (m *mockNAT) GetExternalAddress() (net.IP, error) { + return m.externalAddr, nil +} + +func (m *mockNAT) GetInternalAddress() (net.IP, error) { + return m.internalAddr, nil +} + +func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) { + if m.addMappingErr != nil { + return 0, m.addMappingErr + } + if m.onlyPermanentLeases && timeout != 0 { + return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: 725OnlyPermanentLeasesSupported") + } + externalPort := internalPort + m.mappings[internalPort] = externalPort + m.lastTimeout = timeout + return externalPort, nil +} + +func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if m.deleteMappingErr != nil { + return m.deleteMappingErr + } + delete(m.mappings, internalPort) + return nil +} + +func TestManager_CreateMapping(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, "udp", mapping.Protocol) + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, uint16(51820), mapping.ExternalPort) + assert.Equal(t, "Mock-NAT", mapping.NATType) + assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4()) + assert.Equal(t, defaultMappingTTL, mapping.TTL) +} + +func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) { + m := NewManager() + assert.Nil(t, m.GetMapping()) +} + +func TestManager_GetMapping_ReturnsCopy(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + mapping := m.GetMapping() + require.NotNil(t, mapping) + assert.Equal(t, uint16(51820), mapping.InternalPort) + + // Mutating the returned copy should not affect the manager's mapping. + mapping.ExternalPort = 9999 + assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort) +} + +func TestManager_Cleanup_DeletesMapping(t *testing.T) { + m := NewManager() + m.mapping = &Mapping{ + Protocol: "udp", + InternalPort: 51820, + ExternalPort: 51820, + } + + gateway := newMockNAT() + // Seed the mock so we can verify deletion. + gateway.mappings[51820] = 51820 + + m.cleanup(context.Background(), gateway) + + _, exists := gateway.mappings[51820] + assert.False(t, exists, "mapping should be deleted from gateway") + assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared") +} + +func TestManager_Cleanup_NilMapping(t *testing.T) { + m := NewManager() + gateway := newMockNAT() + + // Should not panic or call gateway. + m.cleanup(context.Background(), gateway) +} + + +func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) { + m := NewManager() + m.wgPort = 51820 + + gateway := newMockNAT() + gateway.onlyPermanentLeases = true + + mapping, err := m.createMapping(context.Background(), gateway) + require.NoError(t, err) + require.NotNil(t, mapping) + + assert.Equal(t, uint16(51820), mapping.InternalPort) + assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease") + assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration") +} + +func TestIsPermanentLeaseRequired(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "UPnP error 725", + err: fmt.Errorf("SOAP fault. Code: | Detail: 725OnlyPermanentLeasesSupported"), + expected: true, + }, + { + name: "wrapped error with 725", + err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: 725")), + expected: true, + }, + { + name: "error 725 with newlines in XML", + err: fmt.Errorf("\n 725\n"), + expected: true, + }, + { + name: "bare 725 without XML tag", + err: fmt.Errorf("error code 725"), + expected: false, + }, + { + name: "unrelated error", + err: fmt.Errorf("connection refused"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err)) + }) + } +} diff --git a/client/internal/routemanager/notifier/notifier_ios.go b/client/internal/routemanager/notifier/notifier_ios.go index bb125cfa4..343d2799e 100644 --- a/client/internal/routemanager/notifier/notifier_ios.go +++ b/client/internal/routemanager/notifier/notifier_ios.go @@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { n.currentPrefixes = newNets n.notify() } - func (n *Notifier) notify() { n.listenerMux.Lock() defer n.listenerMux.Unlock() diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index c73a0dcd1..f15b97d30 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -162,7 +162,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) + hostDNS := []netip.AddrPort{ + netip.MustParseAddrPort("9.9.9.9:53"), + netip.MustParseAddrPort("149.112.112.112:53"), + } + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/server/state_generic.go b/client/server/state_generic.go index 980ba0cda..86475ca42 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 019477d8e..b193d4dfa 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/config" ) +// registerStates registers all states that need crash recovery cleanup. func registerStates(mgr *statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 8897b9c7e..59007f75c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error { func (p *SSHProxy) handleSSHSession(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) if err != nil { @@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) { } if hasCommand { - if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { + if err := serverSession.Run(session.RawCommand()); err != nil { log.Debugf("run command: %v", err) p.handleProxyExitCode(session, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index dba2e88da..b33d5f8f4 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) { cancel() } +// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting +// when forwarding commands to the backend. This is critical for tools like +// Ansible that send commands such as: +// +// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0' +// +// The single quotes must be preserved so the backend shell receives the +// subshell expression as a single argument to -c. +func TestSSHProxy_CommandQuoting(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + sshClient, cleanup := setupProxySSHClient(t) + defer cleanup() + + // These commands simulate what the SSH protocol delivers as exec payloads. + // When a user types: ssh host '/bin/sh -c "( echo hello )"' + // the local shell strips the outer single quotes, and the SSH exec request + // contains the raw string: /bin/sh -c "( echo hello )" + // + // The proxy must forward this string verbatim. Using session.Command() + // (shlex.Split + strings.Join) strips the inner double quotes, breaking + // the command on the backend. + tests := []struct { + name string + command string + expect string + }{ + { + name: "subshell_in_double_quotes", + command: `/bin/sh -c "( echo from-subshell ) && echo outer"`, + expect: "from-subshell\nouter\n", + }, + { + name: "printf_with_special_chars", + command: `/bin/sh -c "printf '%s\n' 'hello world'"`, + expect: "hello world\n", + }, + { + name: "nested_command_substitution", + command: `/bin/sh -c "echo $(echo nested)"`, + expect: "nested\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + session, err := sshClient.NewSession() + require.NoError(t, err) + defer func() { _ = session.Close() }() + + var stderrBuf bytes.Buffer + session.Stderr = &stderrBuf + + outputCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + output, err := session.Output(tc.command) + outputCh <- output + errCh <- err + }() + + select { + case output := <-outputCh: + err := <-errCh + if stderrBuf.Len() > 0 { + t.Logf("stderr: %s", stderrBuf.String()) + } + require.NoError(t, err, "command should succeed: %s", tc.command) + assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command) + case <-time.After(5 * time.Second): + t.Fatalf("command timed out: %s", tc.command) + } + }) + } +} + +// setupProxySSHClient creates a full proxy test environment and returns +// an SSH client connected through the proxy to a backend NetBird SSH server. +func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) { + t.Helper() + + const ( + issuer = "https://test-issuer.example.com" + audience = "test-audience" + ) + + jwksServer, privateKey, jwksURL := setupJWKSServer(t) + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + hostPubKey, err := nbssh.GeneratePublicKey(hostKey) + require.NoError(t, err) + + serverConfig := &server.Config{ + HostKeyPEM: hostKey, + JWT: &server.JWTConfig{ + Issuer: issuer, + Audiences: []string{audience}, + KeysLocation: jwksURL, + }, + } + sshServer := server.New(serverConfig) + sshServer.SetAllowRootLogin(true) + + testUsername := testutil.GetTestUsername(t) + testJWTUser := "test-username" + testUserHash, err := sshuserhash.HashUserID(testJWTUser) + require.NoError(t, err) + + authConfig := &sshauth.Config{ + UserIDClaim: sshauth.DefaultUserIDClaim, + AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash}, + MachineUsers: map[string][]uint32{ + testUsername: {0}, + }, + } + sshServer.UpdateSSHAuth(authConfig) + + sshServerAddr := server.StartTestServer(t, sshServer) + + mockDaemon := startMockDaemon(t) + + host, portStr, err := net.SplitHostPort(sshServerAddr) + require.NoError(t, err) + port, err := strconv.Atoi(portStr) + require.NoError(t, err) + + mockDaemon.setHostKey(host, hostPubKey) + + validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser) + mockDaemon.setJWTToken(validToken) + + proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil) + require.NoError(t, err) + + origStdin := os.Stdin + origStdout := os.Stdout + + stdinReader, stdinWriter, err := os.Pipe() + require.NoError(t, err) + stdoutReader, stdoutWriter, err := os.Pipe() + require.NoError(t, err) + + os.Stdin = stdinReader + os.Stdout = stdoutWriter + + clientConn, proxyConn := net.Pipe() + + go func() { _, _ = io.Copy(stdinWriter, proxyConn) }() + go func() { _, _ = io.Copy(proxyConn, stdoutReader) }() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + go func() { + _ = proxyInstance.Connect(ctx) + }() + + sshConfig := &cryptossh.ClientConfig{ + User: testutil.GetTestUsername(t), + Auth: []cryptossh.AuthMethod{}, + HostKeyCallback: cryptossh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig) + require.NoError(t, err) + + client := cryptossh.NewClient(sshClientConn, chans, reqs) + + cleanupFn := func() { + _ = client.Close() + _ = clientConn.Close() + cancel() + os.Stdin = origStdin + os.Stdout = origStdout + _ = sshServer.Stop() + mockDaemon.stop() + jwksServer.Close() + } + + return client, cleanupFn +} + type mockDaemonServer struct { proto.UnimplementedDaemonServiceServer hostKeys map[string][]byte diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index f12a75961..0e531bb96 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) { } ptyReq, winCh, isPty := session.Pty() - hasCommand := len(session.Command()) > 0 + hasCommand := session.RawCommand() != "" if isPty && !hasCommand { // ssh - PTY interactive session (login) diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index 8e1353151..755172842 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info { systemHostname, _ := os.Hostname() + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + return &Info{ - GoOS: runtime.GOOS, - Kernel: osInfo[0], - Platform: runtime.GOARCH, - OS: osName, - OSVersion: osVersion, - Hostname: extractDeviceName(ctx, systemHostname), - CPUs: runtime.NumCPU(), - NetbirdVersion: version.NetbirdVersion(), - UIVersion: extractUserAgent(ctx), - KernelVersion: osInfo[1], - Environment: env, + GoOS: runtime.GOOS, + Kernel: osInfo[0], + Platform: runtime.GOARCH, + OS: osName, + OSVersion: osVersion, + Hostname: extractDeviceName(ctx, systemHostname), + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + UIVersion: extractUserAgent(ctx), + KernelVersion: osInfo[1], + NetworkAddresses: addrs, + Environment: env, } } diff --git a/client/ui/debug.go b/client/ui/debug.go index 29f73a66a..4ebe4d675 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -24,9 +24,10 @@ import ( // Initial state for the debug collection type debugInitialState struct { - wasDown bool - logLevel proto.LogLevel - isLevelTrace bool + wasDown bool + needsRestoreUp bool + logLevel proto.LogLevel + isLevelTrace bool } // Debug collection parameters @@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, enablePersistence bool, -) error { +) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service up: %v", err) + log.Warnf("failed to bring service up: %v", err) + } else { + log.Info("Service brought up for debug") + time.Sleep(time.Second * 10) } - log.Info("Service brought up for debug") - time.Sleep(time.Second * 10) } if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { - return fmt.Errorf("set log level to TRACE: %v", err) + log.Warnf("failed to set log level to TRACE: %v", err) + } else { + log.Info("Log level set to TRACE for debug") } - log.Info("Log level set to TRACE for debug") } if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - return fmt.Errorf("bring service down: %v", err) + log.Warnf("failed to bring service down: %v", err) + } else { + state.needsRestoreUp = !state.wasDown + time.Sleep(time.Second) } - time.Sleep(time.Second) if enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { - return fmt.Errorf("enable sync response persistence: %v", err) + log.Warnf("failed to enable sync response persistence: %v", err) + } else { + log.Info("Sync response persistence enabled for debug") } - log.Info("Sync response persistence enabled for debug") } if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - return fmt.Errorf("bring service back up: %v", err) + log.Warnf("failed to bring service back up: %v", err) + } else { + state.needsRestoreUp = false + time.Sleep(time.Second * 3) } - time.Sleep(time.Second * 3) if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } - - return nil } func (s *serviceClient) collectDebugData( @@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { - return err - } + s.configureServiceForDebug(conn, state, params.enablePersistence) wg.Wait() progress.progressBar.Hide() @@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection( // Restore service to original state func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { + if state.needsRestoreUp { + if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { + log.Warnf("failed to restore up state: %v", err) + } else { + log.Info("Service state restored to up") + } + } + if state.wasDown { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("Failed to restore down state: %v", err) + log.Warnf("failed to restore down state: %v", err) } else { log.Info("Service state restored to down") } @@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat if !state.isLevelTrace { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { - log.Errorf("Failed to restore log level: %v", err) + log.Warnf("failed to restore log level: %v", err) } else { log.Info("Log level restored to original setting") } diff --git a/combined/cmd/root.go b/combined/cmd/root.go index ea1ff908a..db986b4d4 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/relay/healthcheck" relayServer "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/ws" sharedMetrics "github.com/netbirdio/netbird/shared/metrics" "github.com/netbirdio/netbird/shared/relay/auth" @@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) - var relayAcceptFn func(conn net.Conn) + var relayAcceptFn func(conn listener.Conn) if relaySrv != nil { relayAcceptFn = relaySrv.RelayAccept() } @@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re } // handleRelayWebSocket handles incoming WebSocket connections for the relay service -func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) { +func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) { acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func( return } - lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress) - if err != nil { - _ = wsConn.Close(websocket.StatusInternalError, "internal error") - return - } - log.Debugf("Relay WS client connected from: %s", rAddr) - conn := ws.NewConn(wsConn, lAddr, rAddr) + conn := ws.NewConn(wsConn, rAddr) acceptFn(conn) } diff --git a/go.mod b/go.mod index e9334f85b..a95192600 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( github.com/hashicorp/go-version v1.6.0 github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 + github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/mdlayher/socket v0.5.1 @@ -200,10 +201,12 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/huandu/xstrings v1.5.0 // indirect + github.com/huin/goupnp v1.2.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -213,6 +216,7 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/koron/go-ssdp v0.0.4 // indirect github.com/kr/fs v0.1.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/libdns/libdns v0.2.2 // indirect diff --git a/go.sum b/go.sum index 629388ccb..a1d2bb71f 100644 --- a/go.sum +++ b/go.sum @@ -281,6 +281,8 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY= +github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -291,6 +293,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= +github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= @@ -328,6 +332,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0= +github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -346,6 +352,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= +github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= +github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 3b4b31765..524b099f1 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2115,6 +2115,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv var createdAt, certIssuedAt sql.NullTime var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString var mode, source, sourcePeer sql.NullString + var terminated sql.NullBool err := row.Scan( &s.ID, &s.AccountID, @@ -2135,7 +2136,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv &s.PortAutoAssigned, &source, &sourcePeer, - &s.Terminated, + &terminated, ) if err != nil { return nil, err @@ -2176,7 +2177,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv if sourcePeer.Valid { s.SourcePeer = sourcePeer.String } - + if terminated.Valid { + s.Terminated = terminated.Bool + } s.Targets = []*rpservice.Target{} return &s, nil }) diff --git a/management/server/types/networkmap_benchmark_test.go b/management/server/types/networkmap_benchmark_test.go new file mode 100644 index 000000000..38272e7b0 --- /dev/null +++ b/management/server/types/networkmap_benchmark_test.go @@ -0,0 +1,217 @@ +package types_test + +import ( + "context" + "fmt" + "os" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/types" +) + +type benchmarkScale struct { + name string + peers int + groups int +} + +var defaultScales = []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + {"10000peers_200groups", 10000, 200}, + {"20000peers_200groups", 20000, 200}, + {"30000peers_300groups", 30000, 300}, +} + +func skipCIBenchmark(b *testing.B) { + if os.Getenv("CI") == "true" { + b.Skip("Skipping benchmark in CI") + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Single Peer Network Map Generation +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer. +func BenchmarkNetworkMapGeneration_Components(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// All Peers (UpdateAccountPeers hot path) +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers. +func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) { + skipCIBenchmark(b) + scales := []benchmarkScale{ + {"100peers_5groups", 100, 5}, + {"500peers_20groups", 500, 20}, + {"1000peers_50groups", 1000, 50}, + {"5000peers_100groups", 5000, 100}, + } + + for _, scale := range scales { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + + peerIDs := make([]string, 0, len(account.Peers)) + for peerID := range account.Peers { + peerIDs = append(peerIDs, peerID) + } + + b.Run("components/"+scale.name, func(b *testing.B) { + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Sub-operations +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction. +func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components. +func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run(scale.name, func(b *testing.B) { + account, validatedPeers := scalableTestAccount(scale.peers, scale.groups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = types.CalculateNetworkMapFromComponents(ctx, components) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs. +func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) { + skipCIBenchmark(b) + for _, scale := range defaultScales { + b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourcePoliciesMap() + } + }) + b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetResourceRoutersMap() + } + }) + b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) { + account, _ := scalableTestAccount(scale.peers, scale.groups) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetActiveGroupUsers() + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// Scaling Analysis +// ────────────────────────────────────────────────────────────────────────────── + +// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance. +func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) { + skipCIBenchmark(b) + groupCounts := []int{1, 5, 20, 50, 100, 200, 500} + for _, numGroups := range groupCounts { + b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(1000, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} + +// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance. +func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) { + skipCIBenchmark(b) + peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000} + for _, numPeers := range peerCounts { + numGroups := numPeers / 20 + if numGroups < 1 { + numGroups = 1 + } + b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) { + account, validatedPeers := scalableTestAccount(numPeers, numGroups) + ctx := context.Background() + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } + }) + } +} diff --git a/management/server/types/networkmap_components_correctness_test.go b/management/server/types/networkmap_components_correctness_test.go new file mode 100644 index 000000000..5cd41ff10 --- /dev/null +++ b/management/server/types/networkmap_components_correctness_test.go @@ -0,0 +1,1192 @@ +package types_test + +import ( + "context" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// scalableTestAccountWithoutDefaultPolicy creates an account without the blanket "Allow All" policy. +// Use this for tests that need to verify feature-specific connectivity in isolation. +func scalableTestAccountWithoutDefaultPolicy(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, false) +} + +// scalableTestAccount creates a realistic account with a blanket "Allow All" policy +// plus per-group policies, routes, network resources, posture checks, and DNS settings. +func scalableTestAccount(numPeers, numGroups int) (*types.Account, map[string]struct{}) { + return buildScalableTestAccount(numPeers, numGroups, true) +} + +// buildScalableTestAccount is the core builder. When withDefaultPolicy is true it adds +// a blanket group-all <-> group-all allow rule; when false the only policies are the +// per-group ones, so tests can verify feature-specific connectivity in isolation. +func buildScalableTestAccount(numPeers, numGroups int, withDefaultPolicy bool) (*types.Account, map[string]struct{}) { + peers := make(map[string]*nbpeer.Peer, numPeers) + allGroupPeers := make([]string, 0, numPeers) + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, byte(64 + i/65536), byte((i / 256) % 256), byte(i % 256)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, + IP: ip, + Key: fmt.Sprintf("key-%s", peerID), + DNSLabel: fmt.Sprintf("peer%d", i), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if i == numPeers-2 { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + } + + groups := make(map[string]*types.Group, numGroups+1) + groups["group-all"] = &types.Group{ID: "group-all", Name: "All", Peers: allGroupPeers} + + peersPerGroup := numPeers / numGroups + if peersPerGroup < 1 { + peersPerGroup = 1 + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + groupPeers := make([]string, 0, peersPerGroup) + start := g * peersPerGroup + end := start + peersPerGroup + if end > numPeers { + end = numPeers + } + for i := start; i < end; i++ { + groupPeers = append(groupPeers, fmt.Sprintf("peer-%d", i)) + } + groups[groupID] = &types.Group{ID: groupID, Name: fmt.Sprintf("Group %d", g), Peers: groupPeers} + } + + policies := make([]*types.Policy, 0, numGroups+2) + if withDefaultPolicy { + policies = append(policies, &types.Policy{ + ID: "policy-all", Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-all", Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-all"}, Destinations: []string{"group-all"}, + }}, + }) + } + + for g := range numGroups { + groupID := fmt.Sprintf("group-%d", g) + dstGroup := fmt.Sprintf("group-%d", (g+1)%numGroups) + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-%d", g), Name: fmt.Sprintf("Policy %d", g), Enabled: true, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-%d", g), Name: fmt.Sprintf("Rule %d", g), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"8080"}, + Sources: []string{groupID}, Destinations: []string{dstGroup}, + }}, + }) + } + + if numGroups >= 2 { + policies = append(policies, &types.Policy{ + ID: "policy-drop", Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + } + + numRoutes := numGroups + if numRoutes > 20 { + numRoutes = 20 + } + routes := make(map[route.ID]*route.Route, numRoutes) + for r := range numRoutes { + routeID := route.ID(fmt.Sprintf("route-%d", r)) + peerIdx := (numPeers / 2) + r + if peerIdx >= numPeers { + peerIdx = numPeers - 1 + } + routePeerID := fmt.Sprintf("peer-%d", peerIdx) + groupID := fmt.Sprintf("group-%d", r%numGroups) + routes[routeID] = &route.Route{ + ID: routeID, + Network: netip.MustParsePrefix(fmt.Sprintf("10.%d.0.0/16", r)), + Peer: peers[routePeerID].Key, + PeerID: routePeerID, + Description: fmt.Sprintf("Route %d", r), + Enabled: true, + PeerGroups: []string{groupID}, + Groups: []string{"group-all"}, + AccessControlGroups: []string{groupID}, + AccountID: "test-account", + } + } + + numResources := numGroups / 2 + if numResources < 1 { + numResources = 1 + } + if numResources > 50 { + numResources = 50 + } + + networkResources := make([]*resourceTypes.NetworkResource, 0, numResources) + networksList := make([]*networkTypes.Network, 0, numResources) + networkRouters := make([]*routerTypes.NetworkRouter, 0, numResources) + + routingPeerStart := numPeers * 3 / 4 + for nr := range numResources { + netID := fmt.Sprintf("net-%d", nr) + resID := fmt.Sprintf("res-%d", nr) + routerPeerIdx := routingPeerStart + nr + if routerPeerIdx >= numPeers { + routerPeerIdx = numPeers - 1 + } + routerPeerID := fmt.Sprintf("peer-%d", routerPeerIdx) + + networksList = append(networksList, &networkTypes.Network{ID: netID, Name: fmt.Sprintf("Network %d", nr), AccountID: "test-account"}) + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: fmt.Sprintf("svc-%d.netbird.cloud", nr), + }) + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", nr), NetworkID: netID, Peer: routerPeerID, + Enabled: true, AccountID: "test-account", + }) + + policies = append(policies, &types.Policy{ + ID: fmt.Sprintf("policy-res-%d", nr), Name: fmt.Sprintf("Resource Policy %d", nr), Enabled: true, + SourcePostureChecks: []string{"posture-check-ver"}, + Rules: []*types.PolicyRule{{ + ID: fmt.Sprintf("rule-res-%d", nr), Name: fmt.Sprintf("Allow Resource %d", nr), Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{fmt.Sprintf("group-%d", nr%numGroups)}, + DestinationResource: types.Resource{ID: resID}, + }}, + }) + } + + account := &types.Account{ + Id: "test-account", + Peers: peers, + Groups: groups, + Policies: policies, + Routes: routes, + Users: map[string]*types.User{ + "user-admin": {Id: "user-admin", Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: "test-account"}, + }, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-group-main": { + ID: "ns-group-main", Name: "Main NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: "posture-check-ver", Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: networkResources, + Networks: networksList, + NetworkRouters: networkRouters, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + validatedPeers := make(map[string]struct{}, numPeers) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if i != numPeers-1 { + validatedPeers[peerID] = struct{}{} + } + } + + return account, validatedPeers +} + +// componentsNetworkMap is a convenience wrapper for GetPeerNetworkMapFromComponents. +func componentsNetworkMap(account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + return account.GetPeerNetworkMapFromComponents( + context.Background(), peerID, nbdns.CustomZone{}, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 1. PEER VISIBILITY & GROUPS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PeerVisibility(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, len(validatedPeers)-1-len(nm.OfflinePeers), len(nm.Peers), "peer should see all other validated non-expired peers") +} + +func TestComponents_PeerDoesNotSeeItself(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-0", p.ID, "peer should not see itself") + } +} + +func TestComponents_IntraGroupConnectivity(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "peer-0 should see peer-5 from same group") +} + +func TestComponents_CrossGroupConnectivity(t *testing.T) { + // Without default policy, only per-group policies provide connectivity + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-10"], "peer-0 should see peer-10 from cross-group policy") +} + +func TestComponents_BidirectionalPolicy(t *testing.T) { + // Without default policy so bidirectional visibility comes only from per-group policies + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm20 := componentsNetworkMap(account, "peer-20", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm20) + + peer0SeesPeer20 := false + for _, p := range nm0.Peers { + if p.ID == "peer-20" { + peer0SeesPeer20 = true + } + } + peer20SeesPeer0 := false + for _, p := range nm20.Peers { + if p.ID == "peer-0" { + peer20SeesPeer0 = true + } + } + assert.True(t, peer0SeesPeer20, "peer-0 should see peer-20 via bidirectional policy") + assert.True(t, peer20SeesPeer0, "peer-20 should see peer-0 via bidirectional policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 2. PEER EXPIRATION & ACCOUNT SETTINGS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_ExpiredPeerInOfflineList(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-98"], "expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-98", p.ID, "expired peer should not be in active Peers") + } +} + +func TestComponents_ExpirationDisabledSetting(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.Settings.PeerLoginExpirationEnabled = false + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-98"], "with expiration disabled, peer-98 should be in active Peers") +} + +func TestComponents_LoginExpiration_PeerLevel(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + pastLogin := time.Now().Add(-2 * time.Hour) + account.Peers["peer-5"].LastLogin = &pastLogin + account.Peers["peer-5"].LoginExpirationEnabled = true + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + offlineIDs := make(map[string]bool, len(nm.OfflinePeers)) + for _, p := range nm.OfflinePeers { + offlineIDs[p.ID] = true + } + assert.True(t, offlineIDs["peer-5"], "login-expired peer should be in OfflinePeers") + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-5", p.ID, "login-expired peer should not be in active Peers") + } +} + +func TestComponents_NetworkSerial(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + account.Network.Serial = 42 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Equal(t, uint64(42), nm.Network.Serial, "network serial should match") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 3. NON-VALIDATED PEERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NonValidatedPeerExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, p := range nm.Peers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in Peers") + } + for _, p := range nm.OfflinePeers { + assert.NotEqual(t, "peer-99", p.ID, "non-validated peer should not appear in OfflinePeers") + } +} + +func TestComponents_NonValidatedTargetPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-99", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +func TestComponents_NonExistentPeerGetsEmptyMap(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-does-not-exist", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 4. POLICIES & FIREWALL RULES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_FirewallRulesGenerated(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.FirewallRules, "should have firewall rules from policies") +} + +func TestComponents_DropPolicyGeneratesDropRules(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasDropRule := false + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) { + hasDropRule = true + break + } + } + assert.True(t, hasDropRule, "should have at least one drop firewall rule") +} + +func TestComponents_DisabledPolicyIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, p := range account.Policies { + p.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.Empty(t, nm.Peers, "disabled policies should yield no peers") + assert.Empty(t, nm.FirewallRules, "disabled policies should yield no firewall rules") +} + +func TestComponents_PortPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has8080, has5432 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "8080" { + has8080 = true + } + if rule.Port == "5432" { + has5432 = true + } + } + assert.True(t, has8080, "should have firewall rule for port 8080") + assert.True(t, has5432, "should have firewall rule for port 5432 (drop policy)") +} + +func TestComponents_PortRangePolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + account.Peers["peer-0"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-port-range", Name: "Port Range", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-port-range", Name: "Port Range Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + PortRanges: []types.RulePortRange{{Start: 8000, End: 9000}}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasPortRange := false + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8000 && rule.PortRange.End == 9000 { + hasPortRange = true + break + } + } + assert.True(t, hasPortRange, "should have firewall rule with port range 8000-9000") +} + +func TestComponents_FirewallRuleDirection(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasIn, hasOut := false, false + for _, rule := range nm.FirewallRules { + if rule.Direction == types.FirewallRuleDirectionIN { + hasIn = true + } + if rule.Direction == types.FirewallRuleDirectionOUT { + hasOut = true + } + } + assert.True(t, hasIn, "should have inbound firewall rules") + assert.True(t, hasOut, "should have outbound firewall rules") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 5. ROUTES +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_RoutesIncluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "should have routes") +} + +func TestComponents_DisabledRouteExcluded(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 2) + for _, r := range account.Routes { + r.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + for _, r := range nm.Routes { + assert.True(t, r.Enabled, "only enabled routes should appear") + } +} + +func TestComponents_RoutesFirewallRulesForACG(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.RoutesFirewallRules, "should have route firewall rules for access-controlled routes") +} + +func TestComponents_HARouteDeduplication(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + + haNetwork := netip.MustParsePrefix("172.16.0.0/16") + account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: haNetwork, PeerID: "peer-10", + Peer: account.Peers["peer-10"].Key, Enabled: true, Metric: 100, + Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, AccountID: "test-account", + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: haNetwork, PeerID: "peer-20", + Peer: account.Peers["peer-20"].Key, Enabled: true, Metric: 200, + Groups: []string{"group-all"}, PeerGroups: []string{"group-1"}, AccountID: "test-account", + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + haRoutes := 0 + for _, r := range nm.Routes { + if r.Network == haNetwork { + haRoutes++ + } + } + // Components deduplicates HA routes with the same HA unique ID, returning one entry per HA group + assert.Equal(t, 1, haRoutes, "HA routes with same network should be deduplicated into one entry") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 6. NETWORK RESOURCES & ROUTERS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, routerPeerID, validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Peers, "router peer should see source peers") +} + +func TestComponents_NetworkResourceRoutes_SourcePeerSeesRouterPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + + var routerPeerID string + for _, nr := range account.NetworkRouters { + routerPeerID = nr.Peer + break + } + require.NotEmpty(t, routerPeerID) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs[routerPeerID], "source peer should see router peer for network resource") +} + +func TestComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for _, nr := range account.NetworkResources { + nr.Enabled = false + } + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotNil(t, nm.Network) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 7. POSTURE CHECKS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_PostureCheckFiltering_PassingPeer(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.NotEmpty(t, nm.Routes, "passing peer should have routes including resource routes") +} + +func TestComponents_PostureCheckFiltering_FailingPeer(t *testing.T) { + // peer-0 has version 0.40.0 (passes posture check >= 0.26.0) + // peer-1 has version 0.25.0 (fails posture check >= 0.26.0) + // Resource policies require posture-check-ver, so the failing peer + // should not see the router peer for those resources. + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(100, 5) + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm0) + require.NotNil(t, nm1) + + // The passing peer should have more peers visible (including resource router peers) + // than the failing peer, because the failing peer is excluded from resource policies. + assert.Greater(t, len(nm0.Peers), len(nm1.Peers), + "passing peer (0.40.0) should see more peers than failing peer (0.25.0) due to posture-gated resource policies") +} + +func TestComponents_MultiplePostureChecks(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(50, 2) + + // Keep only the posture-gated policy — remove per-group policies so connectivity is isolated + account.Policies = []*types.Policy{} + + // Set kernel version on peers so the OS posture check can evaluate + for _, p := range account.Peers { + p.Meta.KernelVersion = "5.15.0" + } + + account.PostureChecks = append(account.PostureChecks, &posture.Checks{ + ID: "posture-check-os", Name: "Check OS", + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "0.0.1"}}, + }, + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi Posture", Enabled: true, AccountID: "test-account", + SourcePostureChecks: []string{"posture-check-ver", "posture-check-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi Check Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 (0.40.0, kernel 5.15.0) passes both checks, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + assert.NotEmpty(t, nm0.Peers, "peer passing both posture checks should see destination peers") + + // peer-1 (0.25.0, kernel 5.15.0) fails version check, should NOT see group-1 peers + nm1 := componentsNetworkMap(account, "peer-1", validatedPeers) + require.NotNil(t, nm1) + assert.Empty(t, nm1.Peers, + "peer failing posture check should see no peers when posture-gated policy is the only connectivity") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 8. DNS +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_DNSConfigEnabled(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable, "DNS should be enabled") + assert.NotEmpty(t, nm.DNSConfig.NameServerGroups, "should have nameserver groups") +} + +func TestComponents_DNSDisabledByManagementGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(100, 5) + account.DNSSettings.DisabledManagementGroups = []string{"group-all"} + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.DNSConfig.ServiceEnable, "DNS should be disabled for peer in disabled group") +} + +func TestComponents_DNSNameServerGroupDistribution(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.NameServerGroups["ns-group-0"] = &nbdns.NameServerGroup{ + ID: "ns-group-0", Name: "Group 0 NS", Enabled: true, Groups: []string{"group-0"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + hasGroup0NS := false + for _, ns := range nm0.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NS = true + } + } + assert.True(t, hasGroup0NS, "peer-0 in group-0 should receive ns-group-0") + + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasGroup0NSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-group-0" { + hasGroup0NSForPeer10 = true + } + } + assert.False(t, hasGroup0NSForPeer10, "peer-10 in group-1 should NOT receive ns-group-0") +} + +func TestComponents_DNSCustomZone(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + customZone := nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer0.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-0"].IP.String()}, + {Name: "peer1.netbird.cloud.", Type: 1, Class: "IN", TTL: 300, RData: account.Peers["peer-1"].IP.String()}, + }, + } + + nm := account.GetPeerNetworkMapFromComponents( + context.Background(), "peer-0", customZone, nil, + validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), + nil, account.GetActiveGroupUsers(), + ) + require.NotNil(t, nm) + assert.True(t, nm.DNSConfig.ServiceEnable) +} + +// ────────────────────────────────────────────────────────────────────────────── +// 9. SSH +// ────────────────────────────────────────────────────────────────────────────── + +func TestComponents_SSHPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root"}}, + }}, + }) + + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled for destination peer of SSH policy") +} + +func TestComponents_SSHNotEnabledWithoutPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + assert.False(t, nm.EnableSSH, "SSH should not be enabled without SSH policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 10. CROSS-PEER CONSISTENCY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_AllPeersGetValidMaps verifies that every validated peer gets a +// non-nil map with a consistent network serial and non-empty peer list. +func TestComponents_AllPeersGetValidMaps(t *testing.T) { + account, validatedPeers := scalableTestAccount(50, 5) + for peerID := range account.Peers { + if _, validated := validatedPeers[peerID]; !validated { + continue + } + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + assert.NotEmpty(t, nm.Peers, "validated peer %s should see other peers", peerID) + } +} + +// TestComponents_LargeScaleMapGeneration verifies that components can generate maps +// at larger scales without errors and with consistent output. +func TestComponents_LargeScaleMapGeneration(t *testing.T) { + scales := []struct{ peers, groups int }{ + {500, 20}, + {1000, 50}, + } + for _, s := range scales { + t.Run(fmt.Sprintf("%dpeers_%dgroups", s.peers, s.groups), func(t *testing.T) { + account, validatedPeers := scalableTestAccount(s.peers, s.groups) + testPeers := []string{"peer-0", fmt.Sprintf("peer-%d", s.peers/4), fmt.Sprintf("peer-%d", s.peers/2)} + for _, peerID := range testPeers { + nm := componentsNetworkMap(account, peerID, validatedPeers) + require.NotNil(t, nm, "network map should not be nil for %s", peerID) + assert.NotEmpty(t, nm.Peers, "peer %s should see other peers at scale", peerID) + assert.NotEmpty(t, nm.Routes, "peer %s should have routes at scale", peerID) + assert.Equal(t, account.Network.Serial, nm.Network.Serial, "serial mismatch for %s", peerID) + } + }) + } +} + +// ────────────────────────────────────────────────────────────────────────────── +// 11. PEER-AS-RESOURCE POLICIES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerAsSourceResource verifies that a policy with SourceResource.Type=Peer +// targets only that specific peer as the source. +func TestComponents_PeerAsSourceResource(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-src", Name: "Peer Source Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-src", Name: "Peer Source Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + SourceResource: types.Resource{ID: "peer-0", Type: types.ResourceTypePeer}, + Destinations: []string{"group-1"}, + }}, + }) + + // peer-0 is the source resource, should see group-1 peers + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + has443 := false + for _, rule := range nm0.FirewallRules { + if rule.Port == "443" { + has443 = true + break + } + } + assert.True(t, has443, "peer-0 as source resource should have port 443 rule") +} + +// TestComponents_PeerAsDestinationResource verifies that a policy with DestinationResource.Type=Peer +// targets only that specific peer as the destination. +func TestComponents_PeerAsDestinationResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-peer-dst", Name: "Peer Dest Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-peer-dst", Name: "Peer Dest Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, + Ports: []string{"443"}, + Sources: []string{"group-0"}, + DestinationResource: types.Resource{ID: "peer-15", Type: types.ResourceTypePeer}, + }}, + }) + + // peer-0 is in group-0 (source), should see peer-15 as destination + nm0 := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm0) + + peerIDs := make(map[string]bool, len(nm0.Peers)) + for _, p := range nm0.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-15"], "peer-0 should see peer-15 via peer-as-destination-resource policy") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 12. MULTIPLE RULES PER POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRulesPerPolicy verifies a policy with multiple rules generates +// firewall rules for each. +func TestComponents_MultipleRulesPerPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-rule", Name: "Multi Rule Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-http", Name: "Allow HTTP", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"80"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-https", Name: "Allow HTTPS", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"443"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has80, has443 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "80" { + has80 = true + } + if rule.Port == "443" { + has443 = true + } + } + assert.True(t, has80, "should have firewall rule for port 80 from first rule") + assert.True(t, has443, "should have firewall rule for port 443 from second rule") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 13. SSH AUTHORIZED USERS CONTENT +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_SSHAuthorizedUsersContent verifies that SSH policies populate +// the AuthorizedUsers map with the correct users and machine mappings. +func TestComponents_SSHAuthorizedUsersContent(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}} + account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}} + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "Allow SSH", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolNetbirdSSH, + Bidirectional: false, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + AuthorizedGroups: map[string][]string{"ssh-users": {"root", "admin"}}, + }}, + }) + + // peer-10 is in group-1 (destination) + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be enabled") + assert.NotNil(t, nm.AuthorizedUsers, "AuthorizedUsers should not be nil") + assert.NotEmpty(t, nm.AuthorizedUsers, "AuthorizedUsers should have entries") + + // Check that "root" machine user mapping exists + _, hasRoot := nm.AuthorizedUsers["root"] + _, hasAdmin := nm.AuthorizedUsers["admin"] + assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping") +} + +// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with +// SSHEnabled peer implies legacy SSH access. +func TestComponents_SSHLegacyImpliedSSH(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Enable SSH on the destination peer + account.Peers["peer-10"].SSHEnabled = true + + // The default "Allow All" policy with Protocol=ALL + SSHEnabled peer should imply SSH + nm := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm) + assert.True(t, nm.EnableSSH, "SSH should be implied by ALL protocol policy with SSHEnabled peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 14. ROUTE DEFAULT PERMIT (no AccessControlGroups) +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_RouteDefaultPermit verifies that a route without AccessControlGroups +// generates default permit firewall rules (0.0.0.0/0 source). +func TestComponents_RouteDefaultPermit(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // Add a route without ACGs — this peer is the routing peer + routingPeerID := "peer-5" + account.Routes["route-no-acg"] = &route.Route{ + ID: "route-no-acg", Network: netip.MustParsePrefix("192.168.99.0/24"), + PeerID: routingPeerID, Peer: account.Peers[routingPeerID].Key, + Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"}, + AccessControlGroups: []string{}, + AccountID: "test-account", + } + + // The routing peer should get default permit route firewall rules + nm := componentsNetworkMap(account, routingPeerID, validatedPeers) + require.NotNil(t, nm) + + hasDefaultPermit := false + for _, rfr := range nm.RoutesFirewallRules { + for _, src := range rfr.SourceRanges { + if src == "0.0.0.0/0" || src == "::/0" { + hasDefaultPermit = true + break + } + } + } + assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 15. MULTIPLE ROUTERS PER NETWORK +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_MultipleRoutersPerNetwork verifies that a network resource +// with multiple routers provides routes through all available routers. +func TestComponents_MultipleRoutersPerNetwork(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-multi-router" + resID := "res-multi-router" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Multi Router Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "multi-svc.netbird.cloud", + }) + account.NetworkRouters = append(account.NetworkRouters, + &routerTypes.NetworkRouter{ID: "router-a", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", Metric: 100}, + &routerTypes.NetworkRouter{ID: "router-b", NetworkID: netID, Peer: "peer-15", Enabled: true, AccountID: "test-account", Metric: 200}, + ) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-router-res", Name: "Multi Router Resource", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-multi-router-res", Name: "Allow Multi Router", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is in group-0 (source), should see both router peers + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see router-a (peer-5)") + assert.True(t, peerIDs["peer-15"], "source peer should see router-b (peer-15)") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 16. PEER-AS-NAMESERVER EXCLUSION +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_PeerIsNameserverExcludedFromNSGroup verifies that a peer serving +// as a nameserver does not receive its own NS group in DNS config. +func TestComponents_PeerIsNameserverExcludedFromNSGroup(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + // peer-0 has IP 100.64.0.0 — make it a nameserver + nsIP := account.Peers["peer-0"].IP + account.NameServerGroups["ns-self"] = &nbdns.NameServerGroup{ + ID: "ns-self", Name: "Self NS", Enabled: true, Groups: []string{"group-all"}, + NameServers: []nbdns.NameServer{{IP: netip.AddrFrom4([4]byte{nsIP[0], nsIP[1], nsIP[2], nsIP[3]}), NSType: nbdns.UDPNameServerType, Port: 53}}, + } + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + hasSelfNS := false + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasSelfNS = true + } + } + assert.False(t, hasSelfNS, "peer serving as nameserver should NOT receive its own NS group") + + // peer-10 is NOT the nameserver, should receive the NS group + nm10 := componentsNetworkMap(account, "peer-10", validatedPeers) + require.NotNil(t, nm10) + hasNSForPeer10 := false + for _, ns := range nm10.DNSConfig.NameServerGroups { + if ns.ID == "ns-self" { + hasNSForPeer10 = true + } + } + assert.True(t, hasNSForPeer10, "non-nameserver peer should receive the NS group") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 17. DOMAIN NETWORK RESOURCES +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DomainNetworkResource verifies that domain-based network resources +// produce routes with the correct domain configuration. +func TestComponents_DomainNetworkResource(t *testing.T) { + account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2) + + netID := "net-domain" + resID := "res-domain" + account.Networks = append(account.Networks, &networkTypes.Network{ID: netID, Name: "Domain Network", AccountID: "test-account"}) + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: resID, NetworkID: netID, AccountID: "test-account", Enabled: true, + Address: "api.example.com", Type: "domain", + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-domain", NetworkID: netID, Peer: "peer-5", Enabled: true, AccountID: "test-account", + }) + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain-res", Name: "Domain Resource Policy", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{{ + ID: "rule-domain-res", Name: "Allow Domain", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{"group-0"}, DestinationResource: types.Resource{ID: resID}, + }}, + }) + + // peer-0 is source, should get route to the domain resource via peer-5 + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + peerIDs := make(map[string]bool, len(nm.Peers)) + for _, p := range nm.Peers { + peerIDs[p.ID] = true + } + assert.True(t, peerIDs["peer-5"], "source peer should see domain resource router peer") +} + +// ────────────────────────────────────────────────────────────────────────────── +// 18. DISABLED RULE WITHIN ENABLED POLICY +// ────────────────────────────────────────────────────────────────────────────── + +// TestComponents_DisabledRuleInEnabledPolicy verifies that a disabled rule within +// an enabled policy does not generate firewall rules. +func TestComponents_DisabledRuleInEnabledPolicy(t *testing.T) { + account, validatedPeers := scalableTestAccount(20, 2) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-mixed-rules", Name: "Mixed Rules", Enabled: true, AccountID: "test-account", + Rules: []*types.PolicyRule{ + { + ID: "rule-enabled", Name: "Enabled Rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3000"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + { + ID: "rule-disabled", Name: "Disabled Rule", Enabled: false, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + Bidirectional: true, Ports: []string{"3001"}, + Sources: []string{"group-0"}, Destinations: []string{"group-1"}, + }, + }, + }) + + nm := componentsNetworkMap(account, "peer-0", validatedPeers) + require.NotNil(t, nm) + + has3000, has3001 := false, false + for _, rule := range nm.FirewallRules { + if rule.Port == "3000" { + has3000 = true + } + if rule.Port == "3001" { + has3001 = true + } + } + assert.True(t, has3000, "enabled rule should generate firewall rule for port 3000") + assert.False(t, has3001, "disabled rule should NOT generate firewall rule for port 3001") +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go index 8c3ee1899..067888406 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -1,11 +1,13 @@ package server import ( + "context" "fmt" - "net" + "time" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay/messages" //nolint:staticcheck "github.com/netbirdio/netbird/shared/relay/messages/address" @@ -13,6 +15,12 @@ import ( authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth" ) +const ( + // handshakeTimeout bounds how long a connection may remain in the + // pre-authentication handshake phase before being closed. + handshakeTimeout = 10 * time.Second +) + type Validator interface { Validate(any) error // Deprecated: Use Validate instead. @@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { } type handshake struct { - conn net.Conn + conn listener.Conn validator Validator preparedMsg *preparedMsg @@ -66,9 +74,9 @@ type handshake struct { peerID *messages.PeerID } -func (h *handshake) handshakeReceive() (*messages.PeerID, error) { +func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) { buf := make([]byte, messages.MaxHandshakeSize) - n, err := h.conn.Read(buf) + n, err := h.conn.Read(ctx, buf) if err != nil { return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) } @@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) { return peerID, nil } -func (h *handshake) handshakeResponse() error { +func (h *handshake) handshakeResponse(ctx context.Context) error { var responseMsg []byte if h.handshakeMethodAuth { responseMsg = h.preparedMsg.responseAuthMsg @@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error { responseMsg = h.preparedMsg.responseHelloMsg } - if _, err := h.conn.Write(responseMsg); err != nil { + if _, err := h.conn.Write(ctx, responseMsg); err != nil { return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) } diff --git a/relay/server/listener/conn.go b/relay/server/listener/conn.go new file mode 100644 index 000000000..ef0869594 --- /dev/null +++ b/relay/server/listener/conn.go @@ -0,0 +1,14 @@ +package listener + +import ( + "context" + "net" +) + +// Conn is the relay connection contract implemented by WS and QUIC transports. +type Conn interface { + Read(ctx context.Context, b []byte) (n int, err error) + Write(ctx context.Context, b []byte) (n int, err error) + RemoteAddr() net.Addr + Close() error +} diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go deleted file mode 100644 index 0a79182f4..000000000 --- a/relay/server/listener/listener.go +++ /dev/null @@ -1,14 +0,0 @@ -package listener - -import ( - "context" - "net" - - "github.com/netbirdio/netbird/relay/protocol" -) - -type Listener interface { - Listen(func(conn net.Conn)) error - Shutdown(ctx context.Context) error - Protocol() protocol.Protocol -} diff --git a/relay/server/listener/quic/conn.go b/relay/server/listener/quic/conn.go index 6e2201bf7..d8dafcd1f 100644 --- a/relay/server/listener/quic/conn.go +++ b/relay/server/listener/quic/conn.go @@ -3,33 +3,26 @@ package quic import ( "context" "errors" - "fmt" "net" "sync" - "time" "github.com/quic-go/quic-go" ) type Conn struct { - session *quic.Conn - closed bool - closedMu sync.Mutex - ctx context.Context - ctxCancel context.CancelFunc + session *quic.Conn + closed bool + closedMu sync.Mutex } func NewConn(session *quic.Conn) *Conn { - ctx, cancel := context.WithCancel(context.Background()) return &Conn{ - session: session, - ctx: ctx, - ctxCancel: cancel, + session: session, } } -func (c *Conn) Read(b []byte) (n int, err error) { - dgram, err := c.session.ReceiveDatagram(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + dgram, err := c.session.ReceiveDatagram(ctx) if err != nil { return 0, c.remoteCloseErrHandling(err) } @@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) { return n, nil } -func (c *Conn) Write(b []byte) (int, error) { +func (c *Conn) Write(_ context.Context, b []byte) (int, error) { if err := c.session.SendDatagram(b); err != nil { return 0, c.remoteCloseErrHandling(err) } return len(b), nil } -func (c *Conn) LocalAddr() net.Addr { - return c.session.LocalAddr() -} - func (c *Conn) RemoteAddr() net.Addr { return c.session.RemoteAddr() } -func (c *Conn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() if c.closed { @@ -74,8 +51,6 @@ func (c *Conn) Close() error { c.closed = true c.closedMu.Unlock() - c.ctxCancel() // Cancel the context - sessionErr := c.session.CloseWithError(0, "normal closure") return sessionErr } diff --git a/relay/server/listener/quic/listener.go b/relay/server/listener/quic/listener.go index 797223e74..68f0e03c0 100644 --- a/relay/server/listener/quic/listener.go +++ b/relay/server/listener/quic/listener.go @@ -5,12 +5,12 @@ import ( "crypto/tls" "errors" "fmt" - "net" "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" nbRelay "github.com/netbirdio/netbird/shared/relay" ) @@ -25,7 +25,7 @@ type Listener struct { listener *quic.Listener } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { quicCfg := &quic.Config{ EnableDatagrams: true, InitialPacketSize: nbRelay.QUICInitialPacketSize, diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index d5bce56f7..c22b5719d 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -18,25 +18,21 @@ const ( type Conn struct { *websocket.Conn - lAddr *net.TCPAddr rAddr *net.TCPAddr closed bool closedMu sync.Mutex - ctx context.Context } -func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { +func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn { return &Conn{ Conn: wsConn, - lAddr: lAddr, rAddr: rAddr, - ctx: context.Background(), } } -func (c *Conn) Read(b []byte) (n int, err error) { - t, r, err := c.Reader(c.ctx) +func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) { + t, r, err := c.Reader(ctx) if err != nil { return 0, c.ioErrHandling(err) } @@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Write writes a binary message with the given payload. // It does not block until fill the internal buffer. // If the buffer filled up, wait until the buffer is drained or timeout. -func (c *Conn) Write(b []byte) (int, error) { - ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) +func (c *Conn) Write(ctx context.Context, b []byte) (int, error) { + ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout) defer ctxCancel() err := c.Conn.Write(ctx, websocket.MessageBinary, b) return len(b), err } -func (c *Conn) LocalAddr() net.Addr { - return c.lAddr -} - func (c *Conn) RemoteAddr() net.Addr { return c.rAddr } -func (c *Conn) SetReadDeadline(t time.Time) error { - return fmt.Errorf("SetReadDeadline is not implemented") -} - -func (c *Conn) SetWriteDeadline(t time.Time) error { - return fmt.Errorf("SetWriteDeadline is not implemented") -} - -func (c *Conn) SetDeadline(t time.Time) error { - return fmt.Errorf("SetDeadline is not implemented") -} - func (c *Conn) Close() error { c.closedMu.Lock() c.closed = true diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 12219e29b..ba175f901 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -7,11 +7,13 @@ import ( "fmt" "net" "net/http" + "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/protocol" + relaylistener "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/shared/relay" ) @@ -27,18 +29,19 @@ type Listener struct { TLSConfig *tls.Config server *http.Server - acceptFn func(conn net.Conn) + acceptFn func(conn relaylistener.Conn) } -func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { +func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error { l.acceptFn = acceptFn mux := http.NewServeMux() mux.HandleFunc(URLPath, l.onAccept) l.server = &http.Server{ - Addr: l.Address, - Handler: mux, - TLSConfig: l.TLSConfig, + Addr: l.Address, + Handler: mux, + TLSConfig: l.TLSConfig, + ReadHeaderTimeout: 5 * time.Second, } log.Infof("WS server listening address: %s", l.Address) @@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { return } - lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr) - if err != nil { - err = wsConn.Close(websocket.StatusInternalError, "internal error") - if err != nil { - log.Errorf("failed to close ws connection: %s", err) - } - return - } - log.Infof("WS client connected from: %s", rAddr) - conn := NewConn(wsConn, lAddr, rAddr) + conn := NewConn(wsConn, rAddr) l.acceptFn(conn) } diff --git a/relay/server/peer.go b/relay/server/peer.go index c5ff41857..8376cdfa7 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" @@ -26,11 +27,14 @@ type Peer struct { metrics *metrics.Metrics log *log.Entry id messages.PeerID - conn net.Conn + conn listener.Conn connMu sync.RWMutex store *store.Store notifier *store.PeerNotifier + ctx context.Context + ctxCancel context.CancelFunc + peersListener *store.Listener // between the online peer collection step and the notification sending should not be sent offline notifications from another thread @@ -38,14 +42,17 @@ type Peer struct { } // NewPeer creates a new Peer instance and prepare custom logging -func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { +func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { + ctx, cancel := context.WithCancel(context.Background()) p := &Peer{ - metrics: metrics, - log: log.WithField("peer_id", id.String()), - id: id, - conn: conn, - store: store, - notifier: notifier, + metrics: metrics, + log: log.WithField("peer_id", id.String()), + id: id, + conn: conn, + store: store, + notifier: notifier, + ctx: ctx, + ctxCancel: cancel, } return p @@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store func (p *Peer) Work() { p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) defer func() { + p.ctxCancel() p.notifier.RemoveListener(p.peersListener) if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { @@ -64,8 +72,7 @@ func (p *Peer) Work() { } }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := p.ctx hc := healthcheck.NewSender(p.log) go hc.StartHealthCheck(ctx) @@ -73,7 +80,7 @@ func (p *Peer) Work() { buf := make([]byte, bufferSize) for { - n, err := p.conn.Read(buf) + n, err := p.conn.Read(ctx, buf) if err != nil { if !errors.Is(err, net.ErrClosed) { p.log.Errorf("failed to read message: %s", err) @@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * } // Write writes data to the connection -func (p *Peer) Write(b []byte) (int, error) { +func (p *Peer) Write(ctx context.Context, b []byte) (int, error) { p.connMu.RLock() defer p.connMu.RUnlock() - return p.conn.Write(b) + return p.conn.Write(ctx, b) } // CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the @@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -156,6 +164,7 @@ func (p *Peer) Close() { p.connMu.Lock() defer p.connMu.Unlock() + p.ctxCancel() if err := p.conn.Close(); err != nil { p.log.Errorf(errCloseConn, err) } @@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - writeDone := make(chan struct{}) - var err error - go func() { - _, err = p.conn.Write(buf) - close(writeDone) - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-writeDone: - return err - } + _, err := p.conn.Write(ctx, buf) + return err } func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { for { select { case <-hc.HealthCheck: - _, err := p.Write(messages.MarshalHealthcheck()) + _, err := p.Write(ctx, messages.MarshalHealthcheck()) if err != nil { p.log.Errorf("failed to send healthcheck message: %s", err) return @@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - n, err := dp.Write(msg) + n, err := dp.Write(dp.ctx, msg) if err != nil { p.log.Errorf("failed to write transport message to: %s", dp.String()) return } - p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) + p.metrics.TransferBytesSent.Add(p.ctx, int64(n)) } func (p *Peer) handleSubscribePeerState(msg []byte) { @@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } @@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { } for n, msg := range msgs { - if _, err := p.Write(msg); err != nil { + if _, err := p.Write(p.ctx, msg); err != nil { p.log.Errorf("failed to write %d. peers offline message: %s", n, err) } } diff --git a/relay/server/relay.go b/relay/server/relay.go index bb355f58f..56add8bea 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "net" "net/url" "sync" "time" @@ -13,11 +12,20 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/healthcheck/peerid" + "github.com/netbirdio/netbird/relay/protocol" + "github.com/netbirdio/netbird/relay/server/listener" + //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" ) +type Listener interface { + Listen(func(conn listener.Conn)) error + Shutdown(ctx context.Context) error + Protocol() protocol.Protocol +} + type Config struct { Meter metric.Meter ExposedAddress string @@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) { } // Accept start to handle a new peer connection -func (r *Relay) Accept(conn net.Conn) { +func (r *Relay) Accept(conn listener.Conn) { acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() @@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) { return } + hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout) + defer hsCancel() + h := handshake{ conn: conn, validator: r.validator, preparedMsg: r.preparedMsg, } - peerID, err := h.handshakeReceive() + peerID, err := h.handshakeReceive(hsCtx) if err != nil { if peerid.IsHealthCheck(peerID) { log.Debugf("health check connection from %s", conn.RemoteAddr()) @@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) { r.metrics.PeerDisconnected(peer.String()) }() - if err := h.handshakeResponse(); err != nil { + if err := h.handshakeResponse(hsCtx); err != nil { log.Errorf("failed to send handshake response, close peer: %s", err) peer.Close() } diff --git a/relay/server/server.go b/relay/server/server.go index a0f7eb73c..340da55b8 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/tls" - "net" "net/url" "sync" @@ -31,7 +30,7 @@ type ListenerConfig struct { // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. type Server struct { relay *Relay - listeners []listener.Listener + listeners []Listener listenerMux sync.Mutex } @@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) { } return &Server{ relay: relay, - listeners: make([]listener.Listener, 0, 2), + listeners: make([]Listener, 0, 2), }, nil } @@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error { wg := sync.WaitGroup{} for _, l := range r.listeners { wg.Add(1) - go func(listener listener.Listener) { + go func(listener Listener) { defer wg.Done() errChan <- listener.Listen(r.relay.Accept) }(l) @@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL { // RelayAccept returns the relay's Accept function for handling incoming connections. // This allows external HTTP handlers to route connections to the relay without // starting the relay's own listeners. -func (r *Server) RelayAccept() func(conn net.Conn) { +func (r *Server) RelayAccept() func(conn listener.Conn) { return r.relay.Accept }