mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-09 09:29:57 +00:00
Compare commits
1 Commits
main
...
ipv6-exit-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8eb50c7704 |
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -750,17 +749,11 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -700,13 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,14 +116,6 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -19,46 +19,6 @@ readonly MSG_SEPARATOR="=========================================="
|
||||
# Utility Functions
|
||||
############################################
|
||||
|
||||
check_docker_sock_perms() {
|
||||
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
|
||||
sock="${sock#unix://}"
|
||||
|
||||
if [[ ! -S "$sock" ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
|
||||
local group
|
||||
if [[ "${OSTYPE}" == "darwin"* ]]; then
|
||||
group="$(stat -f '%Sg' "$sock")"
|
||||
else
|
||||
group="$(stat -c '%G' "$sock")"
|
||||
fi
|
||||
|
||||
echo "Cannot access Docker socket: $sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo "Socket permissions:" > /dev/stderr
|
||||
ls -l "$sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
|
||||
if [[ "$group" == "docker" ]]; then
|
||||
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
|
||||
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
|
||||
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
|
||||
echo " newgrp $group" > /dev/stderr
|
||||
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
|
||||
else
|
||||
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
|
||||
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
|
||||
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
|
||||
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
|
||||
fi
|
||||
|
||||
exit 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
@@ -621,15 +581,12 @@ start_services_and_show_instructions() {
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
# Check if docker compose is installed using check_docker_compose function
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_docker_sock_perms
|
||||
|
||||
initialize_default_values
|
||||
configure_domain
|
||||
configure_reverse_proxy
|
||||
|
||||
check_jq
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
check_existing_installation
|
||||
generate_configuration_files
|
||||
|
||||
@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
|
||||
|
||||
rules := []*RouteFirewallRule{&rule}
|
||||
|
||||
if includeIPv6 && r.IsDynamic() {
|
||||
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
|
||||
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
if isDefaultV4 {
|
||||
ruleV6.Destination = "::/0"
|
||||
ruleV6.RouteID = r.ID + "-v6-default"
|
||||
}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
|
||||
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
|
||||
}
|
||||
|
||||
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
|
||||
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
|
||||
// (::/0 source and destination) for peers that support IPv6, mirroring the route
|
||||
// the client installs. Without it, IPv6 traffic is routed to the exit node but
|
||||
// dropped at the forward chain.
|
||||
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
|
||||
routingPeerID := "peer-5"
|
||||
routingPeer := account.Peers[routingPeerID]
|
||||
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
|
||||
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
|
||||
account.Routes["route-exit"] = &route.Route{
|
||||
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
PeerID: routingPeerID, Peer: routingPeer.Key,
|
||||
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
|
||||
AccessControlGroups: []string{},
|
||||
AccountID: "test-account",
|
||||
}
|
||||
|
||||
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, rfr := range nm.RoutesFirewallRules {
|
||||
switch rfr.Destination {
|
||||
case "0.0.0.0/0":
|
||||
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
|
||||
hasV4 = true
|
||||
}
|
||||
case "::/0":
|
||||
if slices.Contains(rfr.SourceRanges, "::/0") {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
|
||||
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// 15. MULTIPLE ROUTERS PER NETWORK
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -417,30 +417,15 @@ if type uname >/dev/null 2>&1; then
|
||||
# Check the availability of a compatible package manager
|
||||
if check_use_bin_variable; then
|
||||
PACKAGE_MANAGER="bin"
|
||||
elif [ -e /run/ostree-booted ]; then
|
||||
if [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v bootc)" ]; then
|
||||
echo "Detected bootc system without rpm-ostree." >&2
|
||||
echo "NetBird cannot be installed via package manager on this system." >&2
|
||||
echo "Options:" >&2
|
||||
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
|
||||
echo " 2. Rebuild your base image with rpm-ostree included" >&2
|
||||
echo " 3. Bake NetBird into your Containerfile" >&2
|
||||
exit 1
|
||||
else
|
||||
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
|
||||
echo "NetBird cannot be installed automatically on this atomic system." >&2
|
||||
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
|
||||
exit 1
|
||||
fi
|
||||
elif [ -x "$(command -v apt-get)" ]; then
|
||||
PACKAGE_MANAGER="apt"
|
||||
echo "The installation will be performed using apt package manager"
|
||||
elif [ -x "$(command -v dnf)" ]; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
echo "The installation will be performed using dnf package manager"
|
||||
elif [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
|
||||
@@ -9,14 +9,12 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -174,19 +172,6 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -376,13 +361,12 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
dialers := c.getDialers()
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -391,9 +375,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -401,7 +382,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -416,7 +396,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -426,9 +406,6 @@ func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -654,53 +631,13 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is implemented by dialers whose connections carry each write in
|
||||
// a single datagram, so a write can be rejected when it exceeds the path's
|
||||
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
|
||||
// WebSocket over TCP) impose no per-write size limit, so the relay client can
|
||||
// fall back to them when a datagram-sized transport rejects a write as too
|
||||
// large. The capability is advertised per dialer rather than hardcoded, so a
|
||||
// new transport only needs to declare whether it is datagram-sized.
|
||||
type DatagramSized interface {
|
||||
DatagramSized()
|
||||
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,9 +4,4 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -51,8 +52,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -91,15 +95,3 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -24,12 +23,6 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -54,7 +47,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
@@ -82,28 +74,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -32,7 +32,6 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -54,21 +53,7 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
@@ -87,30 +72,6 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -250,66 +250,3 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,42 +9,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
return all
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,11 +7,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -79,30 +79,23 @@ type Manager struct {
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
|
||||
// transportFallback is shared across home and foreign relay clients so a
|
||||
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
|
||||
transportFallback *transportFallback
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
tf := newTransportFallback()
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
transportFallback: tf,
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
@@ -294,7 +287,6 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
|
||||
@@ -29,7 +29,6 @@ type ServerPicker struct {
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TransportFallback *transportFallback
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
@@ -71,7 +70,6 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient.SetTransportFallback(sp.TransportFallback)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
)
|
||||
|
||||
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
|
||||
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
|
||||
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
|
||||
// the other only if it fails to connect; no race). The prefer modes trade a
|
||||
// slower connect when the preferred transport is blackholed for deterministic
|
||||
// transport selection.
|
||||
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
|
||||
|
||||
const (
|
||||
// transportFallbackBase is the initial window a relay server avoids
|
||||
// datagram-sized transports after a datagram is rejected as too large.
|
||||
transportFallbackBase = 10 * time.Minute
|
||||
// transportFallbackMax caps the pinned window when failures repeat.
|
||||
transportFallbackMax = 60 * time.Minute
|
||||
)
|
||||
|
||||
// TransportMode selects which relay dialers are used.
|
||||
type TransportMode string
|
||||
|
||||
const (
|
||||
TransportModeAuto TransportMode = "auto"
|
||||
TransportModeQUIC TransportMode = "quic"
|
||||
TransportModeWS TransportMode = "ws"
|
||||
TransportModePreferQUIC TransportMode = "prefer-quic"
|
||||
TransportModePreferWS TransportMode = "prefer-ws"
|
||||
)
|
||||
|
||||
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
|
||||
// or unrecognized value.
|
||||
func transportModeFromEnv() TransportMode {
|
||||
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
|
||||
case "", TransportModeAuto:
|
||||
return TransportModeAuto
|
||||
case TransportModeQUIC:
|
||||
return TransportModeQUIC
|
||||
case TransportModeWS:
|
||||
return TransportModeWS
|
||||
case TransportModePreferQUIC:
|
||||
return TransportModePreferQUIC
|
||||
case TransportModePreferWS:
|
||||
return TransportModePreferWS
|
||||
default:
|
||||
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
|
||||
return TransportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// sequential reports whether the mode tries dialers in order with fallback
|
||||
// instead of racing them concurrently.
|
||||
func (m TransportMode) sequential() bool {
|
||||
return m == TransportModePreferQUIC || m == TransportModePreferWS
|
||||
}
|
||||
|
||||
// transportFallback tracks relay servers that have rejected a datagram-sized
|
||||
// transport (a write too large for the path) and should temporarily avoid such
|
||||
// transports. It is shared across the relay manager so the preference survives
|
||||
// client recreation (foreign relay clients are evicted and rebuilt on
|
||||
// disconnect). Entries are keyed by server URL and expire after a window that
|
||||
// grows on repeated failures.
|
||||
type transportFallback struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*fallbackEntry
|
||||
}
|
||||
|
||||
// fallbackEntry is the per-server state kept by transportFallback.
type fallbackEntry struct {
	// until is the instant the current avoidance window ends.
	until time.Time
	// duration is the length of the most recently started window; it is the
	// value that gets doubled when a failure follows an expired window.
	duration time.Duration
}
|
||||
|
||||
func newTransportFallback() *transportFallback {
|
||||
return &transportFallback{entries: make(map[string]*fallbackEntry)}
|
||||
}
|
||||
|
||||
// avoidDatagramSized reports whether serverURL is currently within a window
|
||||
// where datagram-sized transports should be avoided.
|
||||
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
e := f.entries[serverURL]
|
||||
return e != nil && time.Now().Before(e.until)
|
||||
}
|
||||
|
||||
// recordFailure makes serverURL avoid datagram-sized transports for a window:
|
||||
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
|
||||
// when a datagram transport fails again after a previous window expired. It
|
||||
// returns the active window duration.
|
||||
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
e := f.entries[serverURL]
|
||||
switch {
|
||||
case e == nil:
|
||||
e = &fallbackEntry{duration: transportFallbackBase}
|
||||
f.entries[serverURL] = e
|
||||
case now.Before(e.until):
|
||||
return time.Until(e.until)
|
||||
default:
|
||||
e.duration = min(e.duration*2, transportFallbackMax)
|
||||
}
|
||||
e.until = now.Add(e.duration)
|
||||
return e.duration
|
||||
}
|
||||
|
||||
// nonDatagramSized returns the dialers from in that are not datagram-sized,
|
||||
// preserving order.
|
||||
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
|
||||
out := make([]dialer.DialeFn, 0, len(in))
|
||||
for _, d := range in {
|
||||
if !dialer.IsDatagramSized(d) {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
// closeTrackingConn is a net.Conn test double that remembers whether Close was
// called. It embeds net.Conn only to satisfy the interface; the tests use
// nothing but Close.
type closeTrackingConn struct {
	net.Conn

	// closed becomes true once Close has been called.
	closed bool
}

// Close marks the connection as closed and always reports success.
func (c *closeTrackingConn) Close() error {
	c.closed = true
	return nil
}
|
||||
|
||||
func TestTransportModeFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
want TransportMode
|
||||
}{
|
||||
{"", TransportModeAuto},
|
||||
{"auto", TransportModeAuto},
|
||||
{"quic", TransportModeQUIC},
|
||||
{"QUIC", TransportModeQUIC},
|
||||
{"ws", TransportModeWS},
|
||||
{" Ws ", TransportModeWS},
|
||||
{"prefer-quic", TransportModePreferQUIC},
|
||||
{"prefer-ws", TransportModePreferWS},
|
||||
{"garbage", TransportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.value, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.value)
|
||||
if tc.value == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
assert.Equal(t, tc.want, transportModeFromEnv())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
|
||||
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
|
||||
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
|
||||
|
||||
// A second failure while still inside the window must not grow the window.
|
||||
d = f.recordFailure(url)
|
||||
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
|
||||
require.NotNil(t, f.entries[url])
|
||||
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
|
||||
|
||||
// Expire the window: datagram-sized transport allowed again.
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
|
||||
}
|
||||
|
||||
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
want := transportFallbackBase
|
||||
for i := range 6 {
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, want, d, "window after %d expiries", i)
|
||||
|
||||
// expire the window so the next failure is treated as a repeat
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
|
||||
want = min(want*2, transportFallbackMax)
|
||||
}
|
||||
|
||||
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeAuto(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.True(t, conn.closed, "connection closed to force reconnect")
|
||||
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
|
||||
|
||||
// A second oversized datagram on the same connection must not re-close.
|
||||
conn.closed = false
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
assert.False(t, conn.closed, "single fallback per connection")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
|
||||
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
|
||||
}
|
||||
|
||||
func TestTransportFallbackPerServer(t *testing.T) {
|
||||
f := newTransportFallback()
|
||||
f.recordFailure("rels://a.example:443")
|
||||
|
||||
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
|
||||
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
|
||||
}
|
||||
Reference in New Issue
Block a user