Compare commits

...

7 Commits

35 changed files with 1505 additions and 97 deletions

View File

@@ -3,6 +3,7 @@ package iptables
import (
"errors"
"fmt"
"maps"
"net"
"slices"
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the maps so the persisted state holds a private snapshot. The
// live maps keep being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing them by reference races the two and aborts the process with a
// concurrent map iteration and write.
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
currentState.ACLEntries6 = maps.Clone(m.entries)
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
currentState.ACLEntries = maps.Clone(m.entries)
currentState.ACLIPsetStore = m.ipsetStore.clone()
}
if err := m.stateManager.UpdateState(currentState); err != nil {

View File

@@ -4,6 +4,7 @@ package iptables
import (
"fmt"
"maps"
"net/netip"
"strconv"
"strings"
@@ -749,11 +750,17 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the rule map so the persisted state holds a private snapshot. The
// live map keeps being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing it by reference races the two and aborts the process with a
// concurrent map iteration and write. The ipset counter guards itself
// during marshaling, so it can be shared directly.
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteRules6 = maps.Clone(r.rules)
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteRules = maps.Clone(r.rules)
currentState.RouteIPsetCounter = r.ipsetCounter
}

View File

@@ -1,6 +1,9 @@
package iptables
import "encoding/json"
import (
"encoding/json"
"maps"
)
type ipList struct {
ips map[string]struct{}
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// clone returns a deep copy of the ipList with its own ips map.
func (s *ipList) clone() *ipList {
if s == nil {
return nil
}
return &ipList{ips: maps.Clone(s.ips)}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
}
}
// clone returns a deep copy of the ipsetStore with its own ipsets map and
// independent ipList entries.
func (s *ipsetStore) clone() *ipsetStore {
if s == nil {
return nil
}
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
for name, list := range s.ipsets {
cloned.ipsets[name] = list.clone()
}
return cloned
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok

View File

@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
// context is released rather than leaked until GC.
func (s *DefaultServer) registerFallback() {
originalNameservers := s.hostManager.getOriginalNameservers()
if len(originalNameservers) == 0 {
serverIP := s.service.RuntimeIP()
var servers []netip.AddrPort
for _, ns := range originalNameservers {
if ns == serverIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
continue
}
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
if len(servers) == 0 {
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
s.clearFallback()
return
}
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
handler, err := newUpstreamResolver(
s.ctx,
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
return
}
handler.selectedRoutes = s.selectedRoutes
var servers []netip.AddrPort
for _, ns := range originalNameservers {
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
handler.addRace(servers)
prev := s.fallbackHandler

View File

@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
if m.routeSelector.IsDeselectAll() {
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return

View File

@@ -0,0 +1,71 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
return route.HAMap{
haID: []*route.Route{
{
ID: "r-" + route.ID(netID),
NetID: netID,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Enabled: true,
SkipAutoApply: skipAutoApply,
},
},
}
}
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
})
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
})
t.Run("user selection is not overridden by management", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
})
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
m.routeSelector.DeselectAllRoutes()
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
})
}

View File

@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.deselectAll
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()

View File

@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
entry.Data[context.UserAgentKey] = ctxUserAgent
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}

View File

@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
errChan <- err
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
return
}
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
case <-conn.ctx.Done():
return
}

View File

@@ -12,6 +12,7 @@ const (
RoleKey = nbcontext.RoleKey
UserIDKey = nbcontext.UserIDKey
PeerIDKey = nbcontext.PeerIDKey
UserAgentKey = nbcontext.UserAgentKey
)
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.

View File

@@ -21,6 +21,8 @@ const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
httpRequestDurationPrefix = "management.http.request.duration.ms"
RequestIDHeader = "X-Request-Id"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
//nolint
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
rw.Header().Set(RequestIDHeader, reqID)
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)

View File

@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
rules := []*RouteFirewallRule{&rule}
if includeIPv6 && r.IsDynamic() {
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
if isDefaultV4 {
ruleV6.Destination = "::/0"
ruleV6.RouteID = r.ID + "-v6-default"
}
rules = append(rules, &ruleV6)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
}
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
// (::/0 source and destination) for peers that support IPv6, mirroring the route
// the client installs. Without it, IPv6 traffic is routed to the exit node but
// dropped at the forward chain.
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
routingPeerID := "peer-5"
routingPeer := account.Peers[routingPeerID]
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
account.Routes["route-exit"] = &route.Route{
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
PeerID: routingPeerID, Peer: routingPeer.Key,
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
AccessControlGroups: []string{},
AccountID: "test-account",
}
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
require.NotNil(t, nm)
hasV4 := false
hasV6 := false
for _, rfr := range nm.RoutesFirewallRules {
switch rfr.Destination {
case "0.0.0.0/0":
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
hasV4 = true
}
case "::/0":
if slices.Contains(rfr.SourceRanges, "::/0") {
hasV6 = true
}
}
}
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
}
// ──────────────────────────────────────────────────────────────────────────────
// 15. MULTIPLE ROUTERS PER NETWORK
// ──────────────────────────────────────────────────────────────────────────────

View File

@@ -249,6 +249,7 @@ func runServer(cmd *cobra.Command, args []string) error {
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,

View File

@@ -28,6 +28,10 @@ import (
const deviceNamePrefix = "ingress-proxy-"
const clientStopTimeout = 30 * time.Second
const createProxyPeerTimeout = 30 * time.Second
// backendKey identifies a backend by its host:port from the target URL.
type backendKey string
@@ -162,6 +166,7 @@ type NetBird struct {
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
lifecycleMu sync.Map
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
@@ -177,6 +182,10 @@ type NetBird struct {
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
// startClient runs the post-create client startup. Nil uses runClientStartup;
// tests override it to avoid a real embed client.Start.
startClient func(accountID types.AccountID, client *embed.Client)
}
// ClientDebugInfo contains debug information about a client.
@@ -200,31 +209,20 @@ type skipTLSVerifyContextKey struct{}
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
if n.registerExistingClient(accountID, key, si) {
return nil
}
entry, exists := n.clients[accountID]
if exists {
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
// Use a background context, not the caller's: the management
// connection notification must land even if the request /
// stream that triggered this registration is cancelled.
// Mirrors the async runClientStartup path.
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
if n.registerExistingClient(accountID, key, si) {
return nil
}
@@ -234,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clientsMux.Lock()
n.clients[accountID] = entry
n.clientsMux.Unlock()
@@ -246,17 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip. runClientStartup uses its
// own background context so the caller's request-scoped ctx can't
// cancel the inbound bring-up.
go n.runClientStartup(accountID, entry.client)
transferred = true
go func() {
defer lifecycle.Unlock()
n.startClientStartup(accountID, entry.client)
}()
return nil
}
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
if n.startClient != nil {
n.startClient(accountID, client)
return
}
n.runClientStartup(accountID, client)
}
// registerExistingClient registers the service against an already-present
// client for the account and returns true when it did. It notifies management
// of the new service when the client is already started.
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.Unlock()
return false
}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return true
}
// accountLifecycle returns the per-account lifecycle mutex, serialising client
// creation against teardown so a slow client.Stop cannot race a new
// client.Start for the same account, without blocking clientsMux.
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
return mu.(*sync.Mutex)
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
// and creates an embedded NetBird client. Must be called with the account's
// lifecycle mutex held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -276,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
defer cancel()
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
@@ -444,6 +491,15 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -466,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -490,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
transferred = true
go n.stopClientLocked(accountID, lifecycle, entry)
}
return nil
}
// stopClientLocked releases a client's resources off the caller's goroutine so a
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
// synchronously). It unlocks lifecycle when done so a new client.Start for the
// same account waits for this teardown.
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
defer lifecycle.Unlock()
if entry.inbound != nil && n.stopHandler != nil {
n.stopHandler(accountID, entry.inbound)
}
if entry.transport != nil {
entry.transport.CloseIdleConnections()
}
if entry.insecureTransport != nil {
entry.insecureTransport.CloseIdleConnections()
}
if entry.client == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -22,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
// tests can detect AddPeer reaching client creation.
type signalMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
m.once.Do(func() { close(m.entered) })
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
@@ -52,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
// Skip the real embed client.Start, which would hang against the unreachable
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
nb.startClient = func(types.AccountID, *embed.Client) {}
return nb
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -288,6 +305,7 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("account-1")
// Add first service — creates a new client entry.
@@ -372,6 +390,117 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
assert.False(t, calls[0].connected)
}
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
// stall: RemovePeer must return promptly even when the client teardown blocks,
// because teardown runs off the caller's goroutine. The receive loop calls
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
accountID := types.AccountID("acct-async-teardown")
key := DomainServiceKey("svc.example")
teardownEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
close(teardownEntered)
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
done := make(chan error, 1)
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
select {
case err := <-done:
require.NoError(t, err)
case <-time.After(2 * time.Second):
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
}
select {
case <-teardownEntered:
case <-time.After(2 * time.Second):
t.Fatal("teardown never ran")
}
close(releaseTeardown)
}
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
// new client bringup behind an in-flight teardown for the same account, so a
// slow client.Stop can never race a new client.Start for that account.
//
// It targets the handoff race specifically: AddPeer is launched immediately
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
// This only passes if RemovePeer acquires the lifecycle lock synchronously
// (before returning) and hands it to the teardown goroutine — if the goroutine
// acquired the lock itself, AddPeer could win the lock in this window and start
// a replacement client while the old teardown is still pending.
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("acct-serialize")
key := DomainServiceKey("svc.example")
addEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
// Block teardown until released. If AddPeer ever reaches createClientEntry
// (signalled via the mgmt client below) while we hold the lock, the lock
// failed to serialise and the test fails before we release.
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
// AddPeer got past the lifecycle lock and into client creation.
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
// Launch AddPeer with NO synchronisation against the teardown goroutine.
addReturned := make(chan struct{})
go func() {
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
close(addReturned)
}()
select {
case <-addEntered:
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
case <-addReturned:
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
case <-time.After(300 * time.Millisecond):
}
close(releaseTeardown)
select {
case <-addReturned:
case <-time.After(2 * time.Second):
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
}
}
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
// a fresh context.Background() rather than inheriting the AddPeer

View File

@@ -114,6 +114,10 @@ type Config struct {
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// being applied before the receive loop reconnects to resync. Zero falls
// back to the internal default.
MappingBatchWatchdog time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
@@ -164,6 +168,7 @@ func New(ctx context.Context, cfg Config) *Server {
Private: cfg.Private,
MaxDialTimeout: cfg.MaxDialTimeout,
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
GeoDataDir: cfg.GeoDataDir,
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
CrowdSecAPIKey: cfg.CrowdSecAPIKey,

282
proxy/mapping_stall_test.go Normal file
View File

@@ -0,0 +1,282 @@
package proxy
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// blockingMgmtClient implements roundtrip's managementClient interface.
// CreateProxyPeer parks until release is closed, signalling entry on entered.
// This reproduces the confirmed real-world stall: createClientEntry calls
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
// receive loop calls that path synchronously inside processMappings.
type blockingMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
b.once.Do(func() { close(b.entered) })
// Park until the caller's context is cancelled. In production this ctx is
// the gRPC mapping-stream context with no per-call timeout, so a slow or
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
<-ctx.Done()
return nil, ctx.Err()
}
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
// pre-seeded list of messages, then records how many times Recv advanced. It
// lets the test observe whether the single-threaded receive loop ever gets
// past the first (blocking) batch to pull the second message.
type gatedMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int32
}
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
i := int(atomic.LoadInt32(&g.idx))
if i >= len(g.messages) {
// Block instead of returning EOF so the loop doesn't exit; we only
// care whether the loop ever reaches this second Recv at all.
select {}
}
msg := g.messages[i]
atomic.AddInt32(&g.idx, 1)
return msg, nil
}
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
func (g *gatedMappingStream) CloseSend() error { return nil }
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
func (g *gatedMappingStream) SendMsg(any) error { return nil }
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
// noopNotifier satisfies roundtrip's statusNotifier interface.
type noopNotifier struct{}
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
return nil
}
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
// compile time; none of those methods are called by this test.
type noopProxyClient struct {
proto.ProxyServiceClient
}
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
return &proto.SendStatusUpdateResponse{}, nil
}
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
// mapping receive loop processes batches strictly serially, so when applying
// one batch blocks (here: createClientEntry parked on a synchronous
// CreateProxyPeer call, exactly as observed in production), the loop never
// advances to Recv the next batch. Management can keep sending updates onto
// the stream with no error and no channel overflow, yet the proxy applies
// nothing further — it is stuck.
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
mgmt := &blockingMgmtClient{
entered: make(chan struct{}),
}
nb := roundtrip.NewNetBird(
context.Background(),
"proxy-test",
"proxy.example.com",
roundtrip.ClientConfig{},
logger,
noopNotifier{},
mgmt,
)
s := &Server{
Logger: logger,
netbird: nb,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// First batch: a CREATED mapping for a brand-new account. addMapping ->
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
// no routers/auth need wiring. The second batch exists only to detect
// whether the loop ever advances past the blocked first batch.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-1",
AccountId: "acct-1",
AuthToken: "token-1",
},
},
},
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-2",
AccountId: "acct-2",
AuthToken: "token-2",
},
},
},
},
}
ctx, cancel := context.WithCancel(context.Background())
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
// avoiding any dependency on collaborators this test deliberately leaves
// nil. The deadlock is fully proven before this fires.
t.Cleanup(cancel)
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
}()
// The loop must reach the blocking apply for the first batch.
select {
case <-mgmt.entered:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
}
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
// single-threaded loop cannot advance. The second batch is never pulled,
// even though it is already available on the stream. Give it ample time.
// deliveredCount is atomic; syncDone is intentionally not read here because
// the loop goroutine owns it (reading it from the test would race).
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged in apply")
default:
// Still wedged, as expected.
}
}
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
// path observed in production: a mapping remove tears down the account's last
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
// threaded, a blocked remove wedges the loop: no further mapping updates of any
// kind (create/modify/remove) are applied, while management keeps sending them
// successfully (no send error, no channel-full). Matches the reported symptom:
// the last log line is a remove that stops a client, then silence.
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
enteredRemove := make(chan struct{})
blockRemove := make(chan struct{})
var once sync.Once
s := &Server{
Logger: logger,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
// Stand in for netbird.RemovePeer -> client.Stop hanging on
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
// blocks; later removes return immediately so the recovery assertion
// can observe the loop advancing.
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
first := false
once.Do(func() {
first = true
close(enteredRemove)
})
if !first {
return nil
}
select {
case <-blockRemove:
case <-ctx.Done():
}
return nil
},
}
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
// that must never be applied while the remove is wedged.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
},
},
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
},
},
},
}
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
}()
select {
case <-enteredRemove:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached the blocking remove for the first batch")
}
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
// syncDone is owned by the loop goroutine, so it is not read here.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged on the remove")
default:
}
// Unblock and confirm the wedge was solely the blocked remove: the loop
// then advances and consumes the next batch.
close(blockRemove)
assert.Eventually(t, func() bool {
return stream.deliveredCount() >= 2
}, 2*time.Second, 5*time.Millisecond,
"once the remove unblocks, the loop must advance and consume the next batch")
}

View File

@@ -118,6 +118,9 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// removePeer defaults to netbird.RemovePeer; overridable in tests.
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
@@ -227,6 +230,10 @@ type Server struct {
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// in processMappings before the receive loop reconnects to resync.
// Zero uses defaultMappingBatchWatchdog.
MappingBatchWatchdog time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
@@ -1172,24 +1179,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
connected := false
onConnected := func() { connected = true }
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
// Reset backoff only when a stream actually connected, so immediate
// connect failures still back off instead of spinning.
if connected {
bo.Reset()
}
if streamErr == nil {
return fmt.Errorf("stream closed by server")
@@ -1221,7 +1234,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
@@ -1234,6 +1247,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return fmt.Errorf("create mapping stream: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1242,7 +1256,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
@@ -1263,6 +1277,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
return fmt.Errorf("send sync init: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1307,7 +1322,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
@@ -1391,7 +1408,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
}
@@ -1456,6 +1475,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
return c
}
const defaultMappingBatchWatchdog = 2 * time.Minute
// mappingBatchWatchdog returns the configured batch watchdog or the default.
func (s *Server) mappingBatchWatchdog() time.Duration {
if s.MappingBatchWatchdog > 0 {
return s.MappingBatchWatchdog
}
return defaultMappingBatchWatchdog
}
// processMappingsGuarded applies a batch under a watchdog, returning an error
// if processing exceeds the watchdog so the caller reconnects and resyncs
// instead of wedging silently.
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
batchCtx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
s.processMappings(batchCtx, mappings)
}()
watchdog := s.mappingBatchWatchdog()
timer := time.NewTimer(watchdog)
defer timer.Stop()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
@@ -1951,7 +2008,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
accountID := types.AccountID(mapping.GetAccountId())
svcKey := s.serviceKeyForMapping(mapping)
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
removePeer := s.removePeer
if removePeer == nil {
removePeer = s.netbird.RemovePeer
}
if err := removePeer(ctx, accountID, svcKey); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": mapping.GetId(),

View File

@@ -6,4 +6,5 @@ const (
RoleKey = "role"
UserIDKey = "userID"
PeerIDKey = "peerID"
UserAgentKey = "userAgent"
)

View File

@@ -5107,31 +5107,63 @@ components:
responses:
not_found:
description: Resource not found
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed_simple:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
bad_request:
description: Bad Request
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
internal_error:
description: Internal Server Error
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
validation_failed:
description: Validation failed
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
forbidden:
description: Forbidden
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
requires_authentication:
description: Requires authentication
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content: { }
conflict:
description: Conflict
headers:
X-Request-Id:
$ref: '#/components/headers/X-Request-Id'
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
headers:
X-Request-Id:
description: |
Unique identifier assigned to the request by the server and set on every
response. Useful for correlating client requests with server-side logs.
schema:
type: string
example: cot7r4n3l3vh3qj4qveg
securitySchemes:
BearerAuth:
type: http

View File

@@ -9,12 +9,14 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -172,6 +174,19 @@ type Client struct {
stateSubscription *PeersStateSubscription
mtu uint16
// transportFallback, when set, records datagram-too-large failures so a
// datagram-sized transport is avoided on subsequent connects. Shared via
// the manager.
transportFallback *transportFallback
// datagramFallbackTriggered guards a single fallback per connection so a
// burst of oversized datagrams triggers one reconnect, not many.
datagramFallbackTriggered atomic.Bool
}
// SetTransportFallback wires the shared datagram-transport fallback tracker.
func (c *Client) SetTransportFallback(tf *transportFallback) {
c.transportFallback = tf
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
@@ -361,12 +376,13 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()
mode := transportModeFromEnv()
dialers := c.getDialers(mode)
var conn net.Conn
if c.serverIP.IsValid() {
var err error
conn, err = c.dialRaceDirect(ctx, dialers)
conn, err = c.dialRaceDirect(ctx, mode, dialers)
if err != nil {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
conn = nil
@@ -375,6 +391,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
if conn == nil {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
if mode.sequential() {
rd.WithSequential()
}
var err error
conn, err = rd.Dial(ctx)
if err != nil {
@@ -382,6 +401,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
}
c.relayConn = conn
c.datagramFallbackTriggered.Store(false)
instanceURL, err := c.handShake(ctx)
if err != nil {
@@ -396,7 +416,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
@@ -406,6 +426,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
if mode.sequential() {
rd.WithSequential()
}
return rd.Dial(ctx)
}
@@ -631,13 +654,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
conn := c.relayConn
_, err = conn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
if errors.Is(err, netErr.ErrDatagramTooLarge) {
c.onDatagramTooLarge(conn, err)
} else {
c.log.Errorf("failed to write transport message: %s", err)
}
}
return len(payload), err
}
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
// When a non-datagram transport is available, it records a fallback for this
// server and closes the connection so the reconnect avoids datagram-sized
// transports. A single fallback is triggered per connection regardless of how
// many oversized datagrams arrive. cause carries the datagram size and budget.
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
// Handle one oversized datagram per connection; a burst triggers a single
// fallback (and a single log line), not many.
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
return
}
// If the selected mode offers no non-datagram transport (e.g. pinned to a
// datagram-sized transport), reconnecting would just re-fail, so leave the
// connection up rather than loop.
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
return
}
// Without the shared tracker a reconnect would just select the same
// transport again and re-fail, so leave the connection up rather than loop.
if c.transportFallback == nil {
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
return
}
window := c.transportFallback.recordFailure(c.connectionURL)
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
if err := conn.Close(); err != nil {
c.log.Debugf("close relay connection for transport fallback: %s", err)
}
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {

View File

@@ -0,0 +1,18 @@
package dialer
// DatagramSized is implemented by dialers whose connections carry each write in
// a single datagram, so a write can be rejected when it exceeds the path's
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
// WebSocket over TCP) impose no per-write size limit, so the relay client can
// fall back to them when a datagram-sized transport rejects a write as too
// large. The capability is advertised per dialer rather than hardcoded, so a
// new transport only needs to declare whether it is datagram-sized.
type DatagramSized interface {
DatagramSized()
}
// IsDatagramSized reports whether d produces datagram-sized connections.
func IsDatagramSized(d DialeFn) bool {
_, ok := d.(DatagramSized)
return ok
}

View File

@@ -4,4 +4,9 @@ import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
// ErrDatagramTooLarge is returned when a transport message exceeds the
// QUIC datagram size the path to the relay can carry. The relay client
// treats it as a signal to fall back to a non-datagram transport.
ErrDatagramTooLarge = errors.New("datagram frame too large")
)

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
@@ -52,11 +51,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
if err := c.session.SendDatagram(b); err != nil {
return 0, c.writeErrHandling(err, len(b))
}
return len(b), nil
}
@@ -95,3 +91,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
}
return err
}
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
// datagram size and path budget) so the relay client can fall back to a
// non-datagram transport.
func (c *Conn) writeErrHandling(err error, size int) error {
var tooLarge *quic.DatagramTooLargeError
if errors.As(err, &tooLarge) {
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
}
return c.remoteCloseErrHandling(err)
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
@@ -23,6 +24,12 @@ func (d Dialer) Protocol() string {
return Network
}
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
// carried in QUIC DATAGRAM frames, which must fit a single packet.
func (d Dialer) DatagramSized() {
// Intentional marker method; presence is the capability signal.
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
@@ -47,6 +54,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Tracer: connectionTracer(quicURL),
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
@@ -74,6 +82,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
return conn, nil
}
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
// reason a relay connection closed, so the path MTU settled on and teardown
// cause are visible in logs. Lines carry the relay address as a structured
// field, matching the rest of the relay client logging.
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
relayLog := log.WithField("relay", addr)
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
if done {
relayLog.Infof("QUIC path MTU settled at %d", mtu)
return
}
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
},
ClosedConnection: func(err error) {
relayLog.Debugf("QUIC connection closed: %v", err)
},
}
}
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string

View File

@@ -32,6 +32,7 @@ type RaceDial struct {
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
sequential bool
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
@@ -53,7 +54,21 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
return r
}
// WithSequential makes Dial try the dialers in order, falling back to the next
// only when one fails to connect, instead of racing them concurrently.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithSequential() *RaceDial {
r.sequential = true
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
if r.sequential {
return r.dialSequential(ctx)
}
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(ctx)
@@ -72,6 +87,30 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
return conn, nil
}
// dialSequential tries each dialer in order, returning the first connection and
// falling back to the next on failure.
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
for _, dfn := range r.dialerFns {
if err := ctx.Err(); err != nil {
return nil, err
}
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
cancel()
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
continue
}
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
return conn, nil
}
return nil, errors.New("failed to dial to Relay server on any protocol")
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()

View File

@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
}
}
func TestRaceDialSequentialFallback(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
var firstDialed, secondDialed bool
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
firstDialed = true
return nil, errors.New("quic unreachable")
},
}
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
secondDialed = true
return fallbackConn, nil
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected fallback to succeed, got %v", err)
}
if conn != fallbackConn {
t.Errorf("expected fallback connection, got %v", conn)
}
if !firstDialed || !secondDialed {
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
}
}
func TestRaceDialSequentialPreferredWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return preferredConn, nil
},
}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
t.Errorf("fallback dialer must not be tried when preferred succeeds")
return nil, errors.New("should not happen")
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected preferred to succeed, got %v", err)
}
if conn != preferredConn {
t.Errorf("expected preferred connection, got %v", conn)
}
}

View File

@@ -9,11 +9,42 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
// getDialers returns the ordered dialers for connecting to the relay server. It
// applies the datagram fallback generically: if this server recently rejected a
// datagram-sized transport, those dialers are dropped, leaving the rest.
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
dialers := c.baseDialers(mode)
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
return filtered
}
}
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
return dialers
}
// baseDialers returns the ordered dialers for the mode, before any datagram
// fallback filtering. For racing modes (auto) the order is irrelevant; for
// prefer modes the first entry is tried before falling back to the second.
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
switch mode {
case TransportModeWS:
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
return []dialer.DialeFn{ws.Dialer{}}
case TransportModeQUIC:
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
return []dialer.DialeFn{quic.Dialer{}}
}
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
if mode == TransportModePreferWS {
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
}
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
return nonDatagramSized(all)
}
return all
}

View File

@@ -0,0 +1,101 @@
//go:build !js
package client
import (
"os"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// TestDatagramSizedCapability locks the capability the generic fallback relies
// on: QUIC is datagram-sized, WebSocket is not.
func TestDatagramSizedCapability(t *testing.T) {
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
}
func protocols(dialers []dialer.DialeFn) []string {
out := make([]string, len(dialers))
for i, d := range dialers {
out[i] = d.Protocol()
}
return out
}
func TestGetDialers(t *testing.T) {
const url = "rels://relay.example:443"
tests := []struct {
name string
mode string
mtu uint16
preferWS bool
want []string
}{
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.mode)
if tc.mode == "" {
os.Unsetenv(EnvRelayTransport)
}
tf := newTransportFallback()
if tc.preferWS {
tf.recordFailure(url)
}
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: tc.mtu,
transportFallback: tf,
}
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
})
}
}
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
// datagram records a fallback that makes the next dial pick WebSocket, the way a
// reconnect would after the connection is closed.
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: iface.DefaultMTU,
transportFallback: newTransportFallback(),
}
// First dial races both transports.
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
// An oversized datagram records the fallback for this server.
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
// The reconnect now sticks to WebSocket.
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
}

View File

@@ -7,7 +7,11 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
func (c *Client) getDialers() []dialer.DialeFn {
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
// JS/WASM build only uses WebSocket transport
return []dialer.DialeFn{ws.Dialer{}}
}
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
return []dialer.DialeFn{ws.Dialer{}}
}

View File

@@ -79,23 +79,30 @@ type Manager struct {
cleanupInterval time.Duration
keepUnusedServerTime time.Duration
// transportFallback is shared across home and foreign relay clients so a
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
transportFallback *transportFallback
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{}
tf := newTransportFallback()
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
transportFallback: tf,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TransportFallback: tf,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -287,6 +294,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClientsMutex.Unlock()
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
relayClient.SetTransportFallback(m.transportFallback)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err

View File

@@ -29,6 +29,7 @@ type ServerPicker struct {
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TransportFallback *transportFallback
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
@@ -70,6 +71,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient.SetTransportFallback(sp.TransportFallback)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,

View File

@@ -0,0 +1,129 @@
package client
import (
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
)
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
// the other only if it fails to connect; no race). The prefer modes trade a
// slower connect when the preferred transport is blackholed for deterministic
// transport selection.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
const (
// transportFallbackBase is the initial window a relay server avoids
// datagram-sized transports after a datagram is rejected as too large.
transportFallbackBase = 10 * time.Minute
// transportFallbackMax caps the pinned window when failures repeat.
transportFallbackMax = 60 * time.Minute
)
// TransportMode selects which relay dialers are used.
type TransportMode string
const (
TransportModeAuto TransportMode = "auto"
TransportModeQUIC TransportMode = "quic"
TransportModeWS TransportMode = "ws"
TransportModePreferQUIC TransportMode = "prefer-quic"
TransportModePreferWS TransportMode = "prefer-ws"
)
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
// or unrecognized value.
func transportModeFromEnv() TransportMode {
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
case "", TransportModeAuto:
return TransportModeAuto
case TransportModeQUIC:
return TransportModeQUIC
case TransportModeWS:
return TransportModeWS
case TransportModePreferQUIC:
return TransportModePreferQUIC
case TransportModePreferWS:
return TransportModePreferWS
default:
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
return TransportModeAuto
}
}
// sequential reports whether the mode tries dialers in order with fallback
// instead of racing them concurrently.
func (m TransportMode) sequential() bool {
return m == TransportModePreferQUIC || m == TransportModePreferWS
}
// transportFallback tracks relay servers that have rejected a datagram-sized
// transport (a write too large for the path) and should temporarily avoid such
// transports. It is shared across the relay manager so the preference survives
// client recreation (foreign relay clients are evicted and rebuilt on
// disconnect). Entries are keyed by server URL and expire after a window that
// grows on repeated failures.
type transportFallback struct {
mu sync.Mutex
entries map[string]*fallbackEntry
}
type fallbackEntry struct {
until time.Time
duration time.Duration
}
func newTransportFallback() *transportFallback {
return &transportFallback{entries: make(map[string]*fallbackEntry)}
}
// avoidDatagramSized reports whether serverURL is currently within a window
// where datagram-sized transports should be avoided.
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
f.mu.Lock()
defer f.mu.Unlock()
e := f.entries[serverURL]
return e != nil && time.Now().Before(e.until)
}
// recordFailure makes serverURL avoid datagram-sized transports for a window:
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
// when a datagram transport fails again after a previous window expired. It
// returns the active window duration.
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
f.mu.Lock()
defer f.mu.Unlock()
now := time.Now()
e := f.entries[serverURL]
switch {
case e == nil:
e = &fallbackEntry{duration: transportFallbackBase}
f.entries[serverURL] = e
case now.Before(e.until):
return time.Until(e.until)
default:
e.duration = min(e.duration*2, transportFallbackMax)
}
e.until = now.Add(e.duration)
return e.duration
}
// nonDatagramSized returns the dialers from in that are not datagram-sized,
// preserving order.
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
out := make([]dialer.DialeFn, 0, len(in))
for _, d := range in {
if !dialer.IsDatagramSized(d) {
out = append(out, d)
}
}
return out
}

View File

@@ -0,0 +1,140 @@
package client
import (
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
// closeTrackingConn records whether Close was called; only Close is exercised.
type closeTrackingConn struct {
net.Conn
closed bool
}
func (c *closeTrackingConn) Close() error {
c.closed = true
return nil
}
func TestTransportModeFromEnv(t *testing.T) {
tests := []struct {
value string
want TransportMode
}{
{"", TransportModeAuto},
{"auto", TransportModeAuto},
{"quic", TransportModeQUIC},
{"QUIC", TransportModeQUIC},
{"ws", TransportModeWS},
{" Ws ", TransportModeWS},
{"prefer-quic", TransportModePreferQUIC},
{"prefer-ws", TransportModePreferWS},
{"garbage", TransportModeAuto},
}
for _, tc := range tests {
t.Run(tc.value, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.value)
if tc.value == "" {
os.Unsetenv(EnvRelayTransport)
}
assert.Equal(t, tc.want, transportModeFromEnv())
})
}
}
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
d := f.recordFailure(url)
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
// A second failure while still inside the window must not grow the window.
d = f.recordFailure(url)
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
require.NotNil(t, f.entries[url])
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
// Expire the window: datagram-sized transport allowed again.
f.entries[url].until = time.Now().Add(-time.Second)
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
}
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
want := transportFallbackBase
for i := range 6 {
d := f.recordFailure(url)
assert.Equal(t, want, d, "window after %d expiries", i)
// expire the window so the next failure is treated as a repeat
f.entries[url].until = time.Now().Add(-time.Second)
want = min(want*2, transportFallbackMax)
}
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
}
func TestOnDatagramTooLargeAuto(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.True(t, conn.closed, "connection closed to force reconnect")
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
// A second oversized datagram on the same connection must not re-close.
conn.closed = false
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "single fallback per connection")
}
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
}
func TestTransportFallbackPerServer(t *testing.T) {
f := newTransportFallback()
f.recordFailure("rels://a.example:443")
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
}