Merge remote-tracking branch 'origin/main' into proto-ipv6-overlay

This commit is contained in:
Viktor Liu
2026-04-09 11:58:06 +02:00
38 changed files with 2829 additions and 243 deletions

View File

@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.") cmd.Println("Log level set to trace.")
} }
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message()) cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else { } else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down") cmd.Println("netbird down")
} }
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message()) cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else { } else {
needsRestoreUp = false
cmd.Println("netbird up") cmd.Println("netbird up")
} }
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message()) cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())

View File

@@ -114,6 +114,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string, stateFilePath string,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -123,6 +124,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil, "")

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/internal/updater/installer"
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile. state.json: Anonymized client state dump containing netbird states for the active profile.
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized. metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information. mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information. goroutine.prof: Goroutine profiling information.
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err) log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
} }
if err := g.addServiceParams(); err != nil {
log.Errorf("failed to add service params to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil { if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err) log.Errorf("failed to add metrics to debug bundle: %v", err)
} }
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
return nil return nil
} }
const (
serviceParamsFile = "service.json"
serviceParamsBundle = "service_params.json"
maskedValue = "***"
envVarPrefix = "NB_"
jsonKeyManagementURL = "management_url"
jsonKeyServiceEnv = "service_env_vars"
)
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
func (g *BundleGenerator) addServiceParams() error {
path := filepath.Join(configs.StateDir, serviceParamsFile)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read service params: %w", err)
}
var params map[string]any
if err := json.Unmarshal(data, &params); err != nil {
return fmt.Errorf("parse service params: %w", err)
}
if g.anonymize {
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
}
}
g.sanitizeServiceEnvVars(params)
sanitizedData, err := json.MarshalIndent(params, "", " ")
if err != nil {
return fmt.Errorf("marshal sanitized service params: %w", err)
}
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
return fmt.Errorf("add service params to zip: %w", err)
}
return nil
}
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
if !ok {
return
}
sanitized := make(map[string]any, len(envVars))
for k, v := range envVars {
val, _ := v.(string)
switch {
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
sanitized[k] = maskedValue
case g.anonymize:
sanitized[k] = g.anonymizer.AnonymizeString(val)
default:
sanitized[k] = val
}
}
params[jsonKeyServiceEnv] = sanitized
}
// isSensitiveEnvVar returns true for env var names that may contain secrets.
func isSensitiveEnvVar(key string) bool {
lower := strings.ToLower(key)
for _, s := range sensitiveEnvSubstrings {
if strings.Contains(lower, s) {
return true
}
}
return false
}
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) { func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n") configContent.WriteString("NetBird Client Configuration:\n\n")

View File

@@ -1,8 +1,12 @@
package debug package debug
import ( import (
"archive/zip"
"bytes"
"encoding/json" "encoding/json"
"net" "net"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -10,6 +14,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
} }
} }
func TestIsSensitiveEnvVar(t *testing.T) {
tests := []struct {
key string
sensitive bool
}{
{"NB_SETUP_KEY", true},
{"NB_API_TOKEN", true},
{"NB_CLIENT_SECRET", true},
{"NB_PASSWORD", true},
{"NB_CREDENTIAL", true},
{"NB_LOG_LEVEL", false},
{"NB_MANAGEMENT_URL", false},
{"NB_HOSTNAME", false},
{"HOME", false},
{"PATH", false},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
})
}
}
func TestSanitizeServiceEnvVars(t *testing.T) {
tests := []struct {
name string
anonymize bool
input map[string]any
check func(t *testing.T, params map[string]any)
}{
{
name: "no env vars key",
anonymize: false,
input: map[string]any{"management_url": "https://mgmt.example.com"},
check: func(t *testing.T, params map[string]any) {
t.Helper()
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
_, ok := params[jsonKeyServiceEnv]
assert.False(t, ok, "service_env_vars should not be added")
},
},
{
name: "non-NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "sensitive NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "safe NB vars anonymized when anonymize is true",
anonymize: true,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
"NB_LOG_LEVEL": "debug",
"NB_SETUP_KEY": "secret",
"SOME_OTHER": "val",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
// Safe NB_ values should be anonymized (not the original, not masked)
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
logVal := env["NB_LOG_LEVEL"].(string)
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
// Sensitive and non-NB_ still masked
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
assert.Equal(t, maskedValue, env["SOME_OTHER"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
g := &BundleGenerator{
anonymize: tt.anonymize,
anonymizer: anonymizer,
}
g.sanitizeServiceEnvVars(tt.input)
tt.check(t, tt.input)
})
}
}
func TestAddServiceParams(t *testing.T) {
t.Run("missing service.json returns nil", func(t *testing.T) {
g := &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
}
origStateDir := configs.StateDir
configs.StateDir = t.TempDir()
t.Cleanup(func() { configs.StateDir = origStateDir })
err := g.addServiceParams()
assert.NoError(t, err)
})
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
jsonKeyServiceEnv: map[string]any{
"NB_LOG_LEVEL": "trace",
},
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: true,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
require.Len(t, zr.File, 1)
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
mgmt := result[jsonKeyManagementURL].(string)
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
assert.NotEmpty(t, mgmt)
env := result[jsonKeyServiceEnv].(map[string]any)
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
})
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: false,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
})
}
// Helper function to check if IP is in CGNAT range // Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool { func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{ cgnat := net.IPNet{

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
}) })
} }
// TestLocalResolver_Stop tests cleanup on Stop // TestLocalResolver_Stop tests cleanup on GracefullyStop
func TestLocalResolver_Stop(t *testing.T) { func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) { t.Run("GracefullyStop clears all state", func(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{ resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.", Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com.")) assert.False(t, resolver.isInManagedZone("host.example.com."))
}) })
t.Run("Stop is safe to call multiple times", func(t *testing.T) { t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{ resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.", Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop() resolver.Stop()
}) })
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
lookupStarted := make(chan struct{}) lookupStarted := make(chan struct{})

View File

@@ -187,11 +187,16 @@ func NewDefaultServerIos(
ctx context.Context, ctx context.Context,
wgInterface WGIface, wgInterface WGIface,
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status, statusRecorder *peer.Status,
disableSys bool, disableSys bool,
) *DefaultServer { ) *DefaultServer {
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
return ds return ds
} }

View File

@@ -47,6 +47,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -214,9 +215,10 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client // checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks checks []*mgmProto.Checks
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux) // Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex syncRespMux sync.RWMutex
@@ -263,26 +265,27 @@ func NewEngine(
mobileDep MobileDependency, mobileDep MobileDependency,
) *Engine { ) *Engine {
engine := &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
signal: services.SignalClient, signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient, mgmClient: services.MgmClient,
relayManager: services.RelayManager, relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(), peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep, mobileDep: mobileDep,
STUNs: []*stun.URI{}, STUNs: []*stun.URI{},
TURNs: []*stun.URI{}, TURNs: []*stun.URI{},
networkSerial: 0, networkSerial: 0,
statusRecorder: services.StatusRecorder, statusRecorder: services.StatusRecorder,
stateManager: services.StateManager, stateManager: services.StateManager,
checks: services.Checks, portForwardManager: portforward.NewManager(),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), checks: services.Checks,
jobExecutor: jobexec.NewExecutor(), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
clientMetrics: services.ClientMetrics, jobExecutor: jobexec.NewExecutor(),
updateManager: services.UpdateManager, clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
} }
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -541,6 +544,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place // conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack() e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up // Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil { if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface) e.rpManager.SetInterface(e.wgInterface)
@@ -1627,12 +1637,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
} }
serviceDependencies := peer.ServiceDependencies{ serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
Signaler: e.signaler, Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover, IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager, RelayManager: e.relayManager,
SrWatcher: e.srWatcher, SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics, PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
} }
peerConn, err := peer.NewConn(config, serviceDependencies) peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
@@ -1789,6 +1800,12 @@ func (e *Engine) close() {
if e.rpManager != nil { if e.rpManager != nil {
_ = e.rpManager.Close() _ = e.rpManager.Close()
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) { func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1894,7 +1911,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil return dnsServer, nil
case "ios": case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil return dnsServer, nil
default: default:

View File

@@ -22,6 +22,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder MetricsRecorder MetricsRecorder
} }
@@ -87,16 +89,17 @@ type ConnConfig struct {
} }
type Conn struct { type Conn struct {
Log *log.Entry Log *log.Entry
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string) onDisconnected func(remotePeer string)
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{ var conn = &Conn{
Log: connLog, Log: connLog,
config: config, config: config,
statusRecorder: services.StatusRecorder, statusRecorder: services.StatusRecorder,
signaler: services.Signaler, signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover, iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager, relayManager: services.RelayManager,
srWatcher: services.SrWatcher, srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(), portForwardManager: services.PortForwardManager,
statusICE: worker.NewAtomicStatus(), statusRelay: worker.NewAtomicStatus(),
dumpState: dumpState, statusICE: worker.NewAtomicStatus(),
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), dumpState: dumpState,
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
metricsRecorder: services.MetricsRecorder, wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
} }
return conn, nil return conn, nil

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype" "github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -61,6 +62,9 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events // we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
} }
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, fmt.Errorf("create agent: %w", err) return nil, fmt.Errorf("create agent: %w", err)
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
} }
}() }()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
} }
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok { if !lok || !rok {
continue continue
} }
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
sessionID, sessionID,
local.NetworkType(), local.Type(), local.Address(), local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
stat.CurrentRoundTripTime*1000) stat.CurrentRoundTripTime*1000)
} }
} }

View File

@@ -0,0 +1,26 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled
}

View File

@@ -0,0 +1,280 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"regexp"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
)
const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
// allowing for whitespace/newlines between tags from different router firmware.
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
// which blocks startup for ~10s when no gateway is present (affects all clients).
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
// NewManager creates a new port forwarding manager.
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Infof("port forwarding setup: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway, mapping.TTL)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
return nil, nil, fmt.Errorf("create port mapping: %w", err)
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
ttl := defaultMappingTTL
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
if !isPermanentLeaseRequired(err) {
return nil, err
}
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
ttl = 0
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
return nil, err
}
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
TTL: ttl,
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
return
}
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
}
}
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
func isPermanentLeaseRequired(err error) bool {
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
}

View File

@@ -0,0 +1,39 @@
package portforward
import (
"context"
"net"
"time"
)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -0,0 +1,201 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
onlyPermanentLeases bool
lastTimeout time.Duration
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
if m.onlyPermanentLeases && timeout != 0 {
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
m.lastTimeout = timeout
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
assert.Equal(t, defaultMappingTTL, mapping.TTL)
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
gateway.onlyPermanentLeases = true
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
}
func TestIsPermanentLeaseRequired(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UPnP error 725",
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
expected: true,
},
{
name: "wrapped error with 725",
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
expected: true,
},
{
name: "error 725 with newlines in XML",
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
expected: true,
},
{
name: "bare 725 without XML tag",
err: fmt.Errorf("error code 725"),
expected: false,
},
{
name: "unrelated error",
err: fmt.Errorf("connection refused"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
})
}
}

View File

@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
n.currentPrefixes = newNets n.currentPrefixes = newNets
n.notify() n.notify()
} }
func (n *Notifier) notify() { func (n *Notifier) notify() {
n.listenerMux.Lock() n.listenerMux.Lock()
defer n.listenerMux.Unlock() defer n.listenerMux.Unlock()

View File

@@ -162,7 +162,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/client/ssh/config"
) )
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) { func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config" "github.com/netbirdio/netbird/client/ssh/config"
) )
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) { func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
func (p *SSHProxy) handleSSHSession(session ssh.Session) { func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty() ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0 hasCommand := session.RawCommand() != ""
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User()) sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil { if err != nil {
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
} }
if hasCommand { if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil { if err := serverSession.Run(session.RawCommand()); err != nil {
log.Debugf("run command: %v", err) log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err) p.handleProxyExitCode(session, err)
} }

View File

@@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel() cancel()
} }
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct { type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte hostKeys map[string][]byte

View File

@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
} }
ptyReq, winCh, isPty := session.Pty() ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0 hasCommand := session.RawCommand() != ""
if isPty && !hasCommand { if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login) // ssh <host> - PTY interactive session (login)

View File

@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
return &Info{ return &Info{
GoOS: runtime.GOOS, GoOS: runtime.GOOS,
Kernel: osInfo[0], Kernel: osInfo[0],
Platform: runtime.GOARCH, Platform: runtime.GOARCH,
OS: osName, OS: osName,
OSVersion: osVersion, OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname), Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(), NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx), UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1], KernelVersion: osInfo[1],
Environment: env, NetworkAddresses: addrs,
Environment: env,
} }
} }

View File

@@ -24,9 +24,10 @@ import (
// Initial state for the debug collection // Initial state for the debug collection
type debugInitialState struct { type debugInitialState struct {
wasDown bool wasDown bool
logLevel proto.LogLevel needsRestoreUp bool
isLevelTrace bool logLevel proto.LogLevel
isLevelTrace bool
} }
// Debug collection parameters // Debug collection parameters
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient, conn proto.DaemonServiceClient,
state *debugInitialState, state *debugInitialState,
enablePersistence bool, enablePersistence bool,
) error { ) {
if state.wasDown { if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service up: %v", err) log.Warnf("failed to bring service up: %v", err)
} else {
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
} }
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
} }
if !state.isLevelTrace { if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
return fmt.Errorf("set log level to TRACE: %v", err) log.Warnf("failed to set log level to TRACE: %v", err)
} else {
log.Info("Log level set to TRACE for debug")
} }
log.Info("Log level set to TRACE for debug")
} }
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("bring service down: %v", err) log.Warnf("failed to bring service down: %v", err)
} else {
state.needsRestoreUp = !state.wasDown
time.Sleep(time.Second)
} }
time.Sleep(time.Second)
if enablePersistence { if enablePersistence {
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
Enabled: true, Enabled: true,
}); err != nil { }); err != nil {
return fmt.Errorf("enable sync response persistence: %v", err) log.Warnf("failed to enable sync response persistence: %v", err)
} else {
log.Info("Sync response persistence enabled for debug")
} }
log.Info("Sync response persistence enabled for debug")
} }
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service back up: %v", err) log.Warnf("failed to bring service back up: %v", err)
} else {
state.needsRestoreUp = false
time.Sleep(time.Second * 3)
} }
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err) log.Warnf("failed to start CPU profiling: %v", err)
} }
return nil
} }
func (s *serviceClient) collectDebugData( func (s *serviceClient) collectDebugData(
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
var wg sync.WaitGroup var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress) startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil { s.configureServiceForDebug(conn, state, params.enablePersistence)
return err
}
wg.Wait() wg.Wait()
progress.progressBar.Hide() progress.progressBar.Hide()
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
// Restore service to original state // Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) { func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.needsRestoreUp {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to restore up state: %v", err)
} else {
log.Info("Service state restored to up")
}
}
if state.wasDown { if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("Failed to restore down state: %v", err) log.Warnf("failed to restore down state: %v", err)
} else { } else {
log.Info("Service state restored to down") log.Info("Service state restored to down")
} }
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
if !state.isLevelTrace { if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil { if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Errorf("Failed to restore log level: %v", err) log.Warnf("failed to restore log level: %v", err)
} else { } else {
log.Info("Log level restored to original setting") log.Info("Log level restored to original setting")
} }

View File

@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server" relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics" sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/shared/relay/auth"
@@ -523,7 +524,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn) var relayAcceptFn func(conn listener.Conn)
if relaySrv != nil { if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept() relayAcceptFn = relaySrv.RelayAccept()
} }
@@ -563,7 +564,7 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
} }
// handleRelayWebSocket handles incoming WebSocket connections for the relay service // handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) { func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn listener.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{ acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
} }
@@ -585,15 +586,9 @@ func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(
return return
} }
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr) log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr) conn := ws.NewConn(wsConn, rAddr)
acceptFn(conn) acceptFn(conn)
} }

4
go.mod
View File

@@ -63,6 +63,7 @@ require (
github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5 github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0 github.com/libdns/route53 v1.5.0
github.com/libp2p/go-nat v0.2.0
github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
@@ -200,10 +201,12 @@ require (
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/huandu/xstrings v1.5.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
@@ -213,6 +216,7 @@ require (
github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect github.com/lib/pq v1.10.9 // indirect
github.com/libdns/libdns v0.2.2 // indirect github.com/libdns/libdns v0.2.2 // indirect

8
go.sum
View File

@@ -281,6 +281,8 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/huin/goupnp v1.2.0 h1:uOKW26NG1hsSSbXIZ1IR7XP9Gjd1U8pnLaCMgntmkmY=
github.com/huin/goupnp v1.2.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -291,6 +293,8 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -328,6 +332,8 @@ github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYW
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoKtbmZk=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -346,6 +352,8 @@ github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk=
github.com/libp2p/go-nat v0.2.0/go.mod h1:3MJr+GRpRkyT65EpVPBstXLvOlAPzUVlG6Pwg9ohLJk=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 h1:J56rFEfUTFT9j9CiRXhi1r8lUJ4W5idG3CiaBZGojNU=
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI= github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81/go.mod h1:RD8ML/YdXctQ7qbcizZkw5mZ6l8Ogrl1dodBzVJduwI=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=

View File

@@ -2115,6 +2115,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var createdAt, certIssuedAt sql.NullTime var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString var mode, source, sourcePeer sql.NullString
var terminated sql.NullBool
err := row.Scan( err := row.Scan(
&s.ID, &s.ID,
&s.AccountID, &s.AccountID,
@@ -2135,7 +2136,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.PortAutoAssigned, &s.PortAutoAssigned,
&source, &source,
&sourcePeer, &sourcePeer,
&s.Terminated, &terminated,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2176,7 +2177,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sourcePeer.Valid { if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String s.SourcePeer = sourcePeer.String
} }
if terminated.Valid {
s.Terminated = terminated.Bool
}
s.Targets = []*rpservice.Target{} s.Targets = []*rpservice.Target{}
return &s, nil return &s, nil
}) })

View File

@@ -0,0 +1,217 @@
package types_test
import (
"context"
"fmt"
"os"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/types"
)
type benchmarkScale struct {
name string
peers int
groups int
}
var defaultScales = []benchmarkScale{
{"100peers_5groups", 100, 5},
{"500peers_20groups", 500, 20},
{"1000peers_50groups", 1000, 50},
{"5000peers_100groups", 5000, 100},
{"10000peers_200groups", 10000, 200},
{"20000peers_200groups", 20000, 200},
{"30000peers_300groups", 30000, 300},
}
func skipCIBenchmark(b *testing.B) {
if os.Getenv("CI") == "true" {
b.Skip("Skipping benchmark in CI")
}
}
// ──────────────────────────────────────────────────────────────────────────────
// Single Peer Network Map Generation
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer.
func BenchmarkNetworkMapGeneration_Components(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range defaultScales {
b.Run(scale.name, func(b *testing.B) {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
})
}
}
// ──────────────────────────────────────────────────────────────────────────────
// All Peers (UpdateAccountPeers hot path)
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers.
func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) {
skipCIBenchmark(b)
scales := []benchmarkScale{
{"100peers_5groups", 100, 5},
{"500peers_20groups", 500, 20},
{"1000peers_50groups", 1000, 50},
{"5000peers_100groups", 5000, 100},
}
for _, scale := range scales {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
ctx := context.Background()
peerIDs := make([]string, 0, len(account.Peers))
for peerID := range account.Peers {
peerIDs = append(peerIDs, peerID)
}
b.Run("components/"+scale.name, func(b *testing.B) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ReportAllocs()
b.ResetTimer()
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
})
}
}
// ──────────────────────────────────────────────────────────────────────────────
// Sub-operations
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction.
func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range defaultScales {
b.Run(scale.name, func(b *testing.B) {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
}
})
}
}
// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components.
func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range defaultScales {
b.Run(scale.name, func(b *testing.B) {
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = types.CalculateNetworkMapFromComponents(ctx, components)
}
})
}
}
// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs.
func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) {
skipCIBenchmark(b)
for _, scale := range defaultScales {
b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) {
account, _ := scalableTestAccount(scale.peers, scale.groups)
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetResourcePoliciesMap()
}
})
b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) {
account, _ := scalableTestAccount(scale.peers, scale.groups)
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetResourceRoutersMap()
}
})
b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) {
account, _ := scalableTestAccount(scale.peers, scale.groups)
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetActiveGroupUsers()
}
})
}
}
// ──────────────────────────────────────────────────────────────────────────────
// Scaling Analysis
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance.
func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) {
skipCIBenchmark(b)
groupCounts := []int{1, 5, 20, 50, 100, 200, 500}
for _, numGroups := range groupCounts {
b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) {
account, validatedPeers := scalableTestAccount(1000, numGroups)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
})
}
}
// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance.
func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) {
skipCIBenchmark(b)
peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000}
for _, numPeers := range peerCounts {
numGroups := numPeers / 20
if numGroups < 1 {
numGroups = 1
}
b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) {
account, validatedPeers := scalableTestAccount(numPeers, numGroups)
ctx := context.Background()
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ReportAllocs()
b.ResetTimer()
for range b.N {
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,13 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"net" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/shared/relay/messages/address" "github.com/netbirdio/netbird/shared/relay/messages/address"
@@ -13,6 +15,12 @@ import (
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth" authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
) )
const (
// handshakeTimeout bounds how long a connection may remain in the
// pre-authentication handshake phase before being closed.
handshakeTimeout = 10 * time.Second
)
type Validator interface { type Validator interface {
Validate(any) error Validate(any) error
// Deprecated: Use Validate instead. // Deprecated: Use Validate instead.
@@ -58,7 +66,7 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
} }
type handshake struct { type handshake struct {
conn net.Conn conn listener.Conn
validator Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
@@ -66,9 +74,9 @@ type handshake struct {
peerID *messages.PeerID peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() (*messages.PeerID, error) { func (h *handshake) handshakeReceive(ctx context.Context) (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(ctx, buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err)
} }
@@ -103,7 +111,7 @@ func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
return peerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse(ctx context.Context) error {
var responseMsg []byte var responseMsg []byte
if h.handshakeMethodAuth { if h.handshakeMethodAuth {
responseMsg = h.preparedMsg.responseAuthMsg responseMsg = h.preparedMsg.responseAuthMsg
@@ -111,7 +119,7 @@ func (h *handshake) handshakeResponse() error {
responseMsg = h.preparedMsg.responseHelloMsg responseMsg = h.preparedMsg.responseHelloMsg
} }
if _, err := h.conn.Write(responseMsg); err != nil { if _, err := h.conn.Write(ctx, responseMsg); err != nil {
return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err)
} }

View File

@@ -0,0 +1,14 @@
package listener
import (
"context"
"net"
)
// Conn is the relay connection contract implemented by WS and QUIC transports.
type Conn interface {
Read(ctx context.Context, b []byte) (n int, err error)
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
}

View File

@@ -1,14 +0,0 @@
package listener
import (
"context"
"net"
"github.com/netbirdio/netbird/relay/protocol"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}

View File

@@ -3,33 +3,26 @@ package quic
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
type Conn struct { type Conn struct {
session *quic.Conn session *quic.Conn
closed bool closed bool
closedMu sync.Mutex closedMu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
} }
func NewConn(session *quic.Conn) *Conn { func NewConn(session *quic.Conn) *Conn {
ctx, cancel := context.WithCancel(context.Background())
return &Conn{ return &Conn{
session: session, session: session,
ctx: ctx,
ctxCancel: cancel,
} }
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx) dgram, err := c.session.ReceiveDatagram(ctx)
if err != nil { if err != nil {
return 0, c.remoteCloseErrHandling(err) return 0, c.remoteCloseErrHandling(err)
} }
@@ -38,33 +31,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return n, nil return n, nil
} }
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(_ context.Context, b []byte) (int, error) {
if err := c.session.SendDatagram(b); err != nil { if err := c.session.SendDatagram(b); err != nil {
return 0, c.remoteCloseErrHandling(err) return 0, c.remoteCloseErrHandling(err)
} }
return len(b), nil return len(b), nil
} }
func (c *Conn) LocalAddr() net.Addr {
return c.session.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr() return c.session.RemoteAddr()
} }
func (c *Conn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.closedMu.Lock() c.closedMu.Lock()
if c.closed { if c.closed {
@@ -74,8 +51,6 @@ func (c *Conn) Close() error {
c.closed = true c.closed = true
c.closedMu.Unlock() c.closedMu.Unlock()
c.ctxCancel() // Cancel the context
sessionErr := c.session.CloseWithError(0, "normal closure") sessionErr := c.session.CloseWithError(0, "normal closure")
return sessionErr return sessionErr
} }

View File

@@ -5,12 +5,12 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
nbRelay "github.com/netbirdio/netbird/shared/relay" nbRelay "github.com/netbirdio/netbird/shared/relay"
) )
@@ -25,7 +25,7 @@ type Listener struct {
listener *quic.Listener listener *quic.Listener
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
quicCfg := &quic.Config{ quicCfg := &quic.Config{
EnableDatagrams: true, EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize, InitialPacketSize: nbRelay.QUICInitialPacketSize,

View File

@@ -18,25 +18,21 @@ const (
type Conn struct { type Conn struct {
*websocket.Conn *websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr rAddr *net.TCPAddr
closed bool closed bool
closedMu sync.Mutex closedMu sync.Mutex
ctx context.Context
} }
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { func NewConn(wsConn *websocket.Conn, rAddr *net.TCPAddr) *Conn {
return &Conn{ return &Conn{
Conn: wsConn, Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr, rAddr: rAddr,
ctx: context.Background(),
} }
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(ctx context.Context, b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx) t, r, err := c.Reader(ctx)
if err != nil { if err != nil {
return 0, c.ioErrHandling(err) return 0, c.ioErrHandling(err)
} }
@@ -56,34 +52,18 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// Write writes a binary message with the given payload. // Write writes a binary message with the given payload.
// It does not block until fill the internal buffer. // It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout. // If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(ctx context.Context, b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) ctx, ctxCancel := context.WithTimeout(ctx, writeTimeout)
defer ctxCancel() defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b) err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err return len(b), err
} }
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr return c.rAddr
} }
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.closedMu.Lock() c.closedMu.Lock()
c.closed = true c.closed = true

View File

@@ -7,11 +7,13 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"time"
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol" "github.com/netbirdio/netbird/relay/protocol"
relaylistener "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/shared/relay"
) )
@@ -27,18 +29,19 @@ type Listener struct {
TLSConfig *tls.Config TLSConfig *tls.Config
server *http.Server server *http.Server
acceptFn func(conn net.Conn) acceptFn func(conn relaylistener.Conn)
} }
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { func (l *Listener) Listen(acceptFn func(conn relaylistener.Conn)) error {
l.acceptFn = acceptFn l.acceptFn = acceptFn
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc(URLPath, l.onAccept) mux.HandleFunc(URLPath, l.onAccept)
l.server = &http.Server{ l.server = &http.Server{
Addr: l.Address, Addr: l.Address,
Handler: mux, Handler: mux,
TLSConfig: l.TLSConfig, TLSConfig: l.TLSConfig,
ReadHeaderTimeout: 5 * time.Second,
} }
log.Infof("WS server listening address: %s", l.Address) log.Infof("WS server listening address: %s", l.Address)
@@ -93,18 +96,9 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return return
} }
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
err = wsConn.Close(websocket.StatusInternalError, "internal error")
if err != nil {
log.Errorf("failed to close ws connection: %s", err)
}
return
}
log.Infof("WS client connected from: %s", rAddr) log.Infof("WS client connected from: %s", rAddr)
conn := NewConn(wsConn, lAddr, rAddr) conn := NewConn(wsConn, rAddr)
l.acceptFn(conn) l.acceptFn(conn)
} }

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
"github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/shared/relay/messages"
@@ -26,11 +27,14 @@ type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
id messages.PeerID id messages.PeerID
conn net.Conn conn listener.Conn
connMu sync.RWMutex connMu sync.RWMutex
store *store.Store store *store.Store
notifier *store.PeerNotifier notifier *store.PeerNotifier
ctx context.Context
ctxCancel context.CancelFunc
peersListener *store.Listener peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread // between the online peer collection step and the notification sending should not be sent offline notifications from another thread
@@ -38,14 +42,17 @@ type Peer struct {
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn listener.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
ctx, cancel := context.WithCancel(context.Background())
p := &Peer{ p := &Peer{
metrics: metrics, metrics: metrics,
log: log.WithField("peer_id", id.String()), log: log.WithField("peer_id", id.String()),
id: id, id: id,
conn: conn, conn: conn,
store: store, store: store,
notifier: notifier, notifier: notifier,
ctx: ctx,
ctxCancel: cancel,
} }
return p return p
@@ -57,6 +64,7 @@ func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.ctxCancel()
p.notifier.RemoveListener(p.peersListener) p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
@@ -64,8 +72,7 @@ func (p *Peer) Work() {
} }
}() }()
ctx, cancel := context.WithCancel(context.Background()) ctx := p.ctx
defer cancel()
hc := healthcheck.NewSender(p.log) hc := healthcheck.NewSender(p.log)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
@@ -73,7 +80,7 @@ func (p *Peer) Work() {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
n, err := p.conn.Read(buf) n, err := p.conn.Read(ctx, buf)
if err != nil { if err != nil {
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {
p.log.Errorf("failed to read message: %s", err) p.log.Errorf("failed to read message: %s", err)
@@ -131,10 +138,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
} }
// Write writes data to the connection // Write writes data to the connection
func (p *Peer) Write(b []byte) (int, error) { func (p *Peer) Write(ctx context.Context, b []byte) (int, error) {
p.connMu.RLock() p.connMu.RLock()
defer p.connMu.RUnlock() defer p.connMu.RUnlock()
return p.conn.Write(b) return p.conn.Write(ctx, b)
} }
// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the // CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
@@ -147,6 +154,7 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
} }
p.ctxCancel()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -156,6 +164,7 @@ func (p *Peer) Close() {
p.connMu.Lock() p.connMu.Lock()
defer p.connMu.Unlock() defer p.connMu.Unlock()
p.ctxCancel()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@@ -170,26 +179,15 @@ func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
writeDone := make(chan struct{}) _, err := p.conn.Write(ctx, buf)
var err error return err
go func() {
_, err = p.conn.Write(buf)
close(writeDone)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-writeDone:
return err
}
} }
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
for { for {
select { select {
case <-hc.HealthCheck: case <-hc.HealthCheck:
_, err := p.Write(messages.MarshalHealthcheck()) _, err := p.Write(ctx, messages.MarshalHealthcheck())
if err != nil { if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err) p.log.Errorf("failed to send healthcheck message: %s", err)
return return
@@ -228,12 +226,12 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
n, err := dp.Write(msg) n, err := dp.Write(dp.ctx, msg)
if err != nil { if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String()) p.log.Errorf("failed to write transport message to: %s", dp.String())
return return
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(p.ctx, int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) { func (p *Peer) handleSubscribePeerState(msg []byte) {
@@ -276,7 +274,7 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
} }
for n, msg := range msgs { for n, msg := range msgs {
if _, err := p.Write(msg); err != nil { if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err) p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
} }
} }
@@ -293,7 +291,7 @@ func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
} }
for n, msg := range msgs { for n, msg := range msgs {
if _, err := p.Write(msg); err != nil { if _, err := p.Write(p.ctx, msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err) p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
} }
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@@ -13,11 +12,20 @@ import (
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/healthcheck/peerid" "github.com/netbirdio/netbird/relay/healthcheck/peerid"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store" "github.com/netbirdio/netbird/relay/server/store"
) )
type Listener interface {
Listen(func(conn listener.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}
type Config struct { type Config struct {
Meter metric.Meter Meter metric.Meter
ExposedAddress string ExposedAddress string
@@ -109,7 +117,7 @@ func NewRelay(config Config) (*Relay, error) {
} }
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn listener.Conn) {
acceptTime := time.Now() acceptTime := time.Now()
r.closeMu.RLock() r.closeMu.RLock()
defer r.closeMu.RUnlock() defer r.closeMu.RUnlock()
@@ -117,12 +125,15 @@ func (r *Relay) Accept(conn net.Conn) {
return return
} }
hsCtx, hsCancel := context.WithTimeout(context.Background(), handshakeTimeout)
defer hsCancel()
h := handshake{ h := handshake{
conn: conn, conn: conn,
validator: r.validator, validator: r.validator,
preparedMsg: r.preparedMsg, preparedMsg: r.preparedMsg,
} }
peerID, err := h.handshakeReceive() peerID, err := h.handshakeReceive(hsCtx)
if err != nil { if err != nil {
if peerid.IsHealthCheck(peerID) { if peerid.IsHealthCheck(peerID) {
log.Debugf("health check connection from %s", conn.RemoteAddr()) log.Debugf("health check connection from %s", conn.RemoteAddr())
@@ -154,7 +165,7 @@ func (r *Relay) Accept(conn net.Conn) {
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
}() }()
if err := h.handshakeResponse(); err != nil { if err := h.handshakeResponse(hsCtx); err != nil {
log.Errorf("failed to send handshake response, close peer: %s", err) log.Errorf("failed to send handshake response, close peer: %s", err)
peer.Close() peer.Close()
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"net/url" "net/url"
"sync" "sync"
@@ -31,7 +30,7 @@ type ListenerConfig struct {
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct { type Server struct {
relay *Relay relay *Relay
listeners []listener.Listener listeners []Listener
listenerMux sync.Mutex listenerMux sync.Mutex
} }
@@ -56,7 +55,7 @@ func NewServer(config Config) (*Server, error) {
} }
return &Server{ return &Server{
relay: relay, relay: relay,
listeners: make([]listener.Listener, 0, 2), listeners: make([]Listener, 0, 2),
}, nil }, nil
} }
@@ -86,7 +85,7 @@ func (r *Server) Listen(cfg ListenerConfig) error {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
for _, l := range r.listeners { for _, l := range r.listeners {
wg.Add(1) wg.Add(1)
go func(listener listener.Listener) { go func(listener Listener) {
defer wg.Done() defer wg.Done()
errChan <- listener.Listen(r.relay.Accept) errChan <- listener.Listen(r.relay.Accept)
}(l) }(l)
@@ -139,6 +138,6 @@ func (r *Server) InstanceURL() url.URL {
// RelayAccept returns the relay's Accept function for handling incoming connections. // RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without // This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners. // starting the relay's own listeners.
func (r *Server) RelayAccept() func(conn net.Conn) { func (r *Server) RelayAccept() func(conn listener.Conn) {
return r.relay.Accept return r.relay.Accept
} }