Compare commits

...

10 Commits

Author SHA1 Message Date
Viktor Liu
e2b4f2c0bf Expose relay transport and connection errors in status and metrics 2026-06-05 13:06:25 +02:00
Viktor Liu
d50241b75d Move nonDatagramSized below the transportFallback methods 2026-06-05 08:53:43 +02:00
Viktor Liu
3bad4c939b Run relay datagram fallback once per connection and annotate the datagram-sized marker 2026-06-04 14:29:19 +02:00
Viktor Liu
ac70eb2c8e Make relay datagram fallback transport-agnostic via a dialer capability 2026-06-04 14:15:36 +02:00
Viktor Liu
36f346c0fe Log skipped relay fallback and bail sequential dial on cancelled context 2026-06-04 13:18:23 +02:00
Viktor Liu
fbd97d6da5 Raise wasm build size limit to 60MB 2026-06-04 13:14:15 +02:00
Viktor Liu
3435267a23 Add relay transport selection and fall back to WebSocket on oversized QUIC datagrams 2026-06-04 12:59:48 +02:00
Maycon Santos
eac6d501c3 [infrastructure] allow docker image overrides for getting started (#6335)
* [infrastructure] allow docker image overrides for getting started

Make dashboard and server image configurations overrideable via environment variables

* [infrastructure] update Traefik gRPC rule to include ProxyService PathPrefix

* make Traefik and CrowdSec images configurable via environment variables
2026-06-04 11:24:47 +02:00
Maycon Santos
deeae30612 [misc] Add Codecov integration and coverage reporting across workflows (#6333) 2026-06-03 19:08:45 +02:00
Bethuel Mmbaga
f3cdf163e1 [management] Export ResolveDomain (#6334) 2026-06-03 19:53:57 +03:00
34 changed files with 1000 additions and 80 deletions

View File

@@ -45,4 +45,11 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client

View File

@@ -158,7 +158,16 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -276,9 +285,17 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-exec 'sudo' -coverprofile=coverage.txt \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,relay
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
@@ -326,7 +343,15 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,proxy
test_signal:
name: "Signal / Unit"
@@ -377,9 +402,17 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-exec 'sudo' -coverprofile=coverage.txt \
-timeout 10m ./signal/... ./shared/signal/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,signal
test_management:
name: "Management / Unit"
needs: [build-cache]
@@ -445,10 +478,18 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
go test -tags=devcert -coverprofile=coverage.txt \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... ./shared/management/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,management
benchmark:
name: "Management / Benchmark"
needs: [build-cache]
@@ -687,6 +728,14 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \
go test -tags=integration -coverprofile=coverage.txt \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/server/http/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: integration,management

View File

@@ -65,7 +65,7 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
if [ ${SIZE} -gt 62914560 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
exit 1
fi

View File

@@ -1015,14 +1015,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
@@ -1032,10 +1035,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates
}
relayState := relay.ProbeResult{
URI: instanceAddr,
for _, rs := range states {
relayStates = append(relayStates, relay.ProbeResult{
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
}
return append(relayStates, relayState)
return relayStates
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1395,6 +1402,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
Transport: relayState.Transport,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()

View File

@@ -32,6 +32,9 @@ type ProbeResult struct {
URI string
Err error
Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
}
type StunTurnProbe struct {

View File

@@ -1832,6 +1832,7 @@ type RelayState struct {
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
Transport string `protobuf:"bytes,4,opt,name=transport,proto3" json:"transport,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1887,6 +1888,13 @@ func (x *RelayState) GetError() string {
return ""
}
func (x *RelayState) GetTransport() string {
if x != nil {
return x.Transport
}
return ""
}
type NSGroupState struct {
state protoimpl.MessageState `protogen:"open.v1"`
Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
@@ -6406,12 +6414,13 @@ const file_daemon_proto_rawDesc = "" +
"\x0fManagementState\x12\x10\n" +
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"R\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"p\n" +
"\n" +
"RelayState\x12\x10\n" +
"\x03URI\x18\x01 \x01(\tR\x03URI\x12\x1c\n" +
"\tavailable\x18\x02 \x01(\bR\tavailable\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"r\n" +
"\x05error\x18\x03 \x01(\tR\x05error\x12\x1c\n" +
"\ttransport\x18\x04 \x01(\tR\ttransport\"r\n" +
"\fNSGroupState\x12\x18\n" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +

View File

@@ -370,6 +370,7 @@ message RelayState {
string URI = 1;
bool available = 2;
string error = 3;
string transport = 4;
}
message NSGroupState {

View File

@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
}
type RelayStateOutput struct {
@@ -217,7 +218,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
RelayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
Error: relayErrorString(relay.GetError()),
Transport: relay.GetTransport(),
},
)
@@ -233,6 +235,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
}
}
// relayErrorString flattens a newline-joined aggregated relay error onto a
// single line for status output.
func relayErrorString(s string) string {
return strings.ReplaceAll(s, "\n", "; ")
}
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
@@ -439,6 +447,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} else if relay.Transport != "" {
available = fmt.Sprintf("%s via %s", available, relay.Transport)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)

View File

@@ -641,3 +641,13 @@ func TestTimeAgo(t *testing.T) {
})
}
}
func TestMapRelaysTransport(t *testing.T) {
out := mapRelays([]*proto.RelayState{
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
})
require.Len(t, out.Details, 2)
assert.Equal(t, "quic", out.Details[0].Transport)
assert.Equal(t, "ws", out.Details[1].Transport)
}

View File

@@ -311,11 +311,12 @@ initialize_default_values() {
NETBIRD_STUN_PORT=3478
# Docker images
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
# Combined server replaces separate signal, relay, and management containers
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
NETBIRD_PROXY_IMAGE=${NETBIRD_PROXY_IMAGE:-"netbirdio/reverse-proxy:latest"}
TRAEFIK_IMAGE=${TRAEFIK_IMAGE:-"traefik:v3.6"}
CROWDSEC_IMAGE=${CROWDSEC_IMAGE:-"crowdsecurity/crowdsec:v1.7.7"}
# Reverse proxy configuration
REVERSE_PROXY_TYPE="0"
TRAEFIK_EXTERNAL_NETWORK=""
@@ -656,7 +657,7 @@ render_docker_compose_traefik_builtin() {
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
crowdsec_service="
crowdsec:
image: crowdsecurity/crowdsec:v1.7.7
image: $CROWDSEC_IMAGE
container_name: netbird-crowdsec
restart: unless-stopped
networks: [netbird]
@@ -687,7 +688,7 @@ render_docker_compose_traefik_builtin() {
services:
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
traefik:
image: traefik:v3.6
image: $TRAEFIK_IMAGE
container_name: netbird-traefik
restart: unless-stopped
networks:
@@ -771,7 +772,7 @@ $traefik_dynamic_volume
labels:
- traefik.enable=true
# gRPC router (needs h2c backend for HTTP/2 cleartext)
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`))
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`) || PathPrefix(\`/management.ProxyService/\`))
- traefik.http.routers.netbird-grpc.entrypoints=websecure
- traefik.http.routers.netbird-grpc.tls=true
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt

View File

@@ -122,7 +122,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
s.ResolveDomains(srvCtx)
}
s.PeersManager()
@@ -398,10 +398,10 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}()
}
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// ResolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
func (s *BaseServer) resolveDomains(ctx context.Context) {
func (s *BaseServer) ResolveDomains(ctx context.Context) {
st := s.Store()
setDefault := func(logMsg string, args ...any) {

View File

@@ -22,7 +22,7 @@ func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
@@ -40,7 +40,7 @@ func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
@@ -56,7 +56,7 @@ func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)

View File

@@ -6,6 +6,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -119,8 +120,8 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
}
// PeerConnected increments the number of connected peers and increments number of idle connections
func (m *Metrics) PeerConnected(id string) {
m.peers.Add(m.ctx, 1)
func (m *Metrics) PeerConnected(id, transport string) {
m.peers.Add(m.ctx, 1, metric.WithAttributes(attribute.String("transport", transport)))
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()
@@ -138,8 +139,8 @@ func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
}
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
func (m *Metrics) PeerDisconnected(id string) {
m.peers.Add(m.ctx, -1)
func (m *Metrics) PeerDisconnected(id, transport string) {
m.peers.Add(m.ctx, -1, metric.WithAttributes(attribute.String("transport", transport)))
m.mutexActivity.Lock()
defer m.mutexActivity.Unlock()

View File

@@ -11,4 +11,6 @@ type Conn interface {
Write(ctx context.Context, b []byte) (n int, err error)
RemoteAddr() net.Addr
Close() error
// Protocol returns the transport name.
Protocol() string
}

View File

@@ -42,6 +42,11 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return "quic"
}
func (c *Conn) Close() error {
c.closedMu.Lock()
if c.closed {

View File

@@ -64,6 +64,11 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return "ws"
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true

View File

@@ -154,15 +154,16 @@ func (r *Relay) Accept(conn listener.Conn) {
}
r.notifier.PeerCameOnline(peer.ID())
transport := conn.Protocol()
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
r.metrics.PeerConnected(peer.String(), transport)
go func() {
peer.Work()
if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
r.metrics.PeerDisconnected(peer.String(), transport)
}()
if err := h.handshakeResponse(hsCtx); err != nil {

View File

@@ -9,12 +9,14 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
)
@@ -143,6 +145,11 @@ func (cc *connContainer) close() {
}
}
// transportConn is implemented by relay connections that know their transport.
type transportConn interface {
Protocol() string
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
@@ -172,6 +179,31 @@ type Client struct {
stateSubscription *PeersStateSubscription
mtu uint16
// transportFallback, when set, records datagram-too-large failures so a
// datagram-sized transport is avoided on subsequent connects. Shared via
// the manager.
transportFallback *transportFallback
// datagramFallbackTriggered guards a single fallback per connection so a
// burst of oversized datagrams triggers one reconnect, not many.
datagramFallbackTriggered atomic.Bool
// transport is the negotiated relay transport of the
// current connection, guarded by mu.
transport string
}
// Transport returns the negotiated relay transport of the current connection,
// or an empty string when not connected.
func (c *Client) Transport() string {
c.mu.Lock()
defer c.mu.Unlock()
return c.transport
}
// SetTransportFallback wires the shared datagram-transport fallback tracker.
func (c *Client) SetTransportFallback(tf *transportFallback) {
c.transportFallback = tf
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
@@ -361,12 +393,13 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()
mode := transportModeFromEnv()
dialers := c.getDialers(mode)
var conn net.Conn
if c.serverIP.IsValid() {
var err error
conn, err = c.dialRaceDirect(ctx, dialers)
conn, err = c.dialRaceDirect(ctx, mode, dialers)
if err != nil {
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
conn = nil
@@ -375,6 +408,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
if conn == nil {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
if mode.sequential() {
rd.WithSequential()
}
var err error
conn, err = rd.Dial(ctx)
if err != nil {
@@ -382,6 +418,10 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
}
c.relayConn = conn
c.datagramFallbackTriggered.Store(false)
if tc, ok := conn.(transportConn); ok {
c.transport = tc.Protocol()
}
instanceURL, err := c.handShake(ctx)
if err != nil {
@@ -396,7 +436,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
}
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
if err != nil {
return nil, fmt.Errorf("substitute host: %w", err)
@@ -406,6 +446,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
WithServerName(serverName)
if mode.sequential() {
rd.WithSequential()
}
return rd.Dial(ctx)
}
@@ -631,13 +674,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
conn := c.relayConn
_, err = conn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
if errors.Is(err, netErr.ErrDatagramTooLarge) {
c.onDatagramTooLarge(conn, err)
} else {
c.log.Errorf("failed to write transport message: %s", err)
}
}
return len(payload), err
}
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
// When a non-datagram transport is available, it records a fallback for this
// server and closes the connection so the reconnect avoids datagram-sized
// transports. A single fallback is triggered per connection regardless of how
// many oversized datagrams arrive. cause carries the datagram size and budget.
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
// Handle one oversized datagram per connection; a burst triggers a single
// fallback (and a single log line), not many.
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
return
}
// If the selected mode offers no non-datagram transport (e.g. pinned to a
// datagram-sized transport), reconnecting would just re-fail, so leave the
// connection up rather than loop.
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
return
}
// Without the shared tracker a reconnect would just select the same
// transport again and re-fail, so leave the connection up rather than loop.
if c.transportFallback == nil {
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
return
}
window := c.transportFallback.recordFailure(c.connectionURL)
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
if err := conn.Close(); err != nil {
c.log.Debugf("close relay connection for transport fallback: %s", err)
}
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
@@ -729,6 +812,7 @@ func (c *Client) close(gracefullyExit bool) error {
return nil
}
c.serviceIsRunning = false
c.transport = ""
c.muInstanceURL.Lock()
c.instanceURL = nil

View File

@@ -0,0 +1,18 @@
package dialer
// DatagramSized is implemented by dialers whose connections carry each write in
// a single datagram, so a write can be rejected when it exceeds the path's
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
// WebSocket over TCP) impose no per-write size limit, so the relay client can
// fall back to them when a datagram-sized transport rejects a write as too
// large. The capability is advertised per dialer rather than hardcoded, so a
// new transport only needs to declare whether it is datagram-sized.
type DatagramSized interface {
DatagramSized()
}
// IsDatagramSized reports whether d produces datagram-sized connections.
func IsDatagramSized(d DialeFn) bool {
_, ok := d.(DatagramSized)
return ok
}

View File

@@ -4,4 +4,9 @@ import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
// ErrDatagramTooLarge is returned when a transport message exceeds the
// QUIC datagram size the path to the relay can carry. The relay client
// treats it as a signal to fall back to a non-datagram transport.
ErrDatagramTooLarge = errors.New("datagram frame too large")
)

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
@@ -52,15 +51,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
if err := c.session.SendDatagram(b); err != nil {
return 0, c.writeErrHandling(err, len(b))
}
return len(b), nil
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return Network
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
@@ -95,3 +96,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
}
return err
}
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
// datagram size and path budget) so the relay client can fall back to a
// non-datagram transport.
func (c *Conn) writeErrHandling(err error, size int) error {
var tooLarge *quic.DatagramTooLargeError
if errors.As(err, &tooLarge) {
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
}
return c.remoteCloseErrHandling(err)
}

View File

@@ -2,13 +2,13 @@ package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
@@ -23,6 +23,12 @@ func (d Dialer) Protocol() string {
return Network
}
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
// carried in QUIC DATAGRAM frames, which must fit a single packet.
func (d Dialer) DatagramSized() {
// Intentional marker method; presence is the capability signal.
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
@@ -47,26 +53,21 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
Tracer: connectionTracer(quicURL),
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
return nil, fmt.Errorf("listen udp: %w", err)
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
return nil, fmt.Errorf("resolve %s: %w", quicURL, err)
}
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
@@ -74,6 +75,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
return conn, nil
}
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
// reason a relay connection closed, so the path MTU settled on and teardown
// cause are visible in logs. Lines carry the relay address as a structured
// field, matching the rest of the relay client logging.
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
relayLog := log.WithField("relay", addr)
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
if done {
relayLog.Infof("QUIC path MTU settled at %d", mtu)
return
}
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
},
ClosedConnection: func(err error) {
relayLog.Debugf("QUIC connection closed: %v", err)
},
}
}
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string

View File

@@ -3,6 +3,7 @@ package dialer
import (
"context"
"errors"
"fmt"
"net"
"time"
@@ -32,6 +33,7 @@ type RaceDial struct {
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
sequential bool
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
@@ -53,9 +55,24 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
return r
}
// WithSequential makes Dial try the dialers in order, falling back to the next
// only when one fails to connect, instead of racing them concurrently.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithSequential() *RaceDial {
r.sequential = true
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
if r.sequential {
return r.dialSequential(ctx)
}
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
errChan := make(chan error, 1)
abortCtx, abort := context.WithCancel(ctx)
defer abort()
@@ -63,15 +80,41 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
go r.processResults(connChan, winnerConn, errChan, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
return nil, <-errChan
}
return conn, nil
}
// dialSequential tries each dialer in order, returning the first connection and
// falling back to the next on failure.
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
var errs []error
for _, dfn := range r.dialerFns {
if err := ctx.Err(); err != nil {
return nil, err
}
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
cancel()
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
errs = append(errs, fmt.Errorf("%s: %w", dfn.Protocol(), err))
continue
}
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
return conn, nil
}
return nil, dialErr(errs)
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()
@@ -81,8 +124,9 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, errChan chan error, abort context.CancelFunc) {
var hasWinner bool
errsByProtocol := make(map[string]error)
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
@@ -90,6 +134,7 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
errsByProtocol[dr.Protocol] = fmt.Errorf("%s: %w", dr.Protocol, dr.Err)
}
continue
}
@@ -107,5 +152,29 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
hasWinner = true
winnerConn <- dr.Conn
}
if !hasWinner {
errChan <- dialErr(r.orderedErrs(errsByProtocol))
}
close(winnerConn)
}
// orderedErrs returns the per-protocol errors in dialer order, so the combined
// error is stable regardless of which attempt failed first.
func (r *RaceDial) orderedErrs(byProtocol map[string]error) []error {
errs := make([]error, 0, len(byProtocol))
for _, dfn := range r.dialerFns {
if err, ok := byProtocol[dfn.Protocol()]; ok {
errs = append(errs, err)
}
}
return errs
}
// dialErr combines per-dialer failures, preserving the underlying reasons
// (e.g. "connection refused") rather than a generic message.
func dialErr(errs []error) error {
if len(errs) == 0 {
return errors.New("no relay transport available")
}
return errors.Join(errs...)
}

View File

@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
}
}
func TestRaceDialSequentialFallback(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
var firstDialed, secondDialed bool
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
firstDialed = true
return nil, errors.New("quic unreachable")
},
}
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
secondDialed = true
return fallbackConn, nil
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected fallback to succeed, got %v", err)
}
if conn != fallbackConn {
t.Errorf("expected fallback connection, got %v", conn)
}
if !firstDialed || !secondDialed {
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
}
}
func TestRaceDialSequentialPreferredWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
preferred := &MockDialer{
protocolStr: "quic",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return preferredConn, nil
},
}
fallback := &MockDialer{
protocolStr: "ws",
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
t.Errorf("fallback dialer must not be tried when preferred succeeds")
return nil, errors.New("should not happen")
},
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Fatalf("expected preferred to succeed, got %v", err)
}
if conn != preferredConn {
t.Errorf("expected preferred connection, got %v", conn)
}
}

View File

@@ -33,6 +33,11 @@ func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn)
}
}
// Protocol returns the transport name for this connection.
func (c *Conn) Protocol() string {
return Network
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {

View File

@@ -22,7 +22,7 @@ type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
return Network
}
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
@@ -39,7 +39,12 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
// websocket.Dial wraps the cause in verbose layers; surface the
// underlying network error when present.
var opErr *net.OpError
if errors.As(err, &opErr) {
return nil, opErr
}
return nil, err
}
if resp.Body != nil {

View File

@@ -9,11 +9,42 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// getDialers returns the list of dialers to use for connecting to the relay server.
func (c *Client) getDialers() []dialer.DialeFn {
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
return []dialer.DialeFn{ws.Dialer{}}
// getDialers returns the ordered dialers for connecting to the relay server. It
// applies the datagram fallback generically: if this server recently rejected a
// datagram-sized transport, those dialers are dropped, leaving the rest.
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
dialers := c.baseDialers(mode)
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
return filtered
}
}
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
return dialers
}
// baseDialers returns the ordered dialers for the mode, before any datagram
// fallback filtering. For racing modes (auto) the order is irrelevant; for
// prefer modes the first entry is tried before falling back to the second.
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
switch mode {
case TransportModeWS:
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
return []dialer.DialeFn{ws.Dialer{}}
case TransportModeQUIC:
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
return []dialer.DialeFn{quic.Dialer{}}
}
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
if mode == TransportModePreferWS {
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
}
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
return nonDatagramSized(all)
}
return all
}

View File

@@ -0,0 +1,101 @@
//go:build !js
package client
import (
"os"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
// TestDatagramSizedCapability locks the capability the generic fallback relies
// on: QUIC is datagram-sized, WebSocket is not.
func TestDatagramSizedCapability(t *testing.T) {
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
}
func protocols(dialers []dialer.DialeFn) []string {
out := make([]string, len(dialers))
for i, d := range dialers {
out[i] = d.Protocol()
}
return out
}
func TestGetDialers(t *testing.T) {
const url = "rels://relay.example:443"
tests := []struct {
name string
mode string
mtu uint16
preferWS bool
want []string
}{
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"ws"}},
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"ws", "quic"}},
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"ws"}},
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.mode)
if tc.mode == "" {
os.Unsetenv(EnvRelayTransport)
}
tf := newTransportFallback()
if tc.preferWS {
tf.recordFailure(url)
}
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: tc.mtu,
transportFallback: tf,
}
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
})
}
}
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
// datagram records a fallback that makes the next dial pick WebSocket, the way a
// reconnect would after the connection is closed.
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
mtu: iface.DefaultMTU,
transportFallback: newTransportFallback(),
}
// First dial races both transports.
assert.Equal(t, []string{"quic", "ws"}, protocols(c.getDialers(transportModeFromEnv())))
// An oversized datagram records the fallback for this server.
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
// The reconnect now sticks to WebSocket.
assert.Equal(t, []string{"ws"}, protocols(c.getDialers(transportModeFromEnv())))
}

View File

@@ -7,7 +7,11 @@ import (
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
)
func (c *Client) getDialers() []dialer.DialeFn {
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
// JS/WASM build only uses WebSocket transport
return []dialer.DialeFn{ws.Dialer{}}
}
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
return []dialer.DialeFn{ws.Dialer{}}
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
@@ -20,6 +21,10 @@ type Guard struct {
// maxBackoffInterval caps the exponential backoff between reconnect
// attempts.
maxBackoffInterval time.Duration
// lastErr is the error from the most recent failed reconnect attempt,
// surfaced as the home relay status while disconnected.
lastErr atomic.Pointer[error]
}
// NewGuard creates a new guard for the relay client. A non-positive
@@ -37,6 +42,19 @@ func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard {
return g
}
// LastError returns the error from the most recent failed reconnect attempt, or
// nil if reconnection last succeeded.
func (g *Guard) LastError() error {
if p := g.lastErr.Load(); p != nil {
return *p
}
return nil
}
func (g *Guard) setLastError(err error) {
g.lastErr.Store(&err)
}
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
// to the same server that was used before, if the server URL is still valid. If the quick
@@ -63,6 +81,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
case <-ticker.C:
if err := g.retry(ctx); err != nil {
log.Errorf("failed to pick new Relay server: %s", err)
g.setLastError(err)
continue
}
return
@@ -89,6 +108,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
g.setLastError(err)
return false
}
return true
@@ -100,6 +120,7 @@ func (g *Guard) retry(ctx context.Context) error {
if err != nil {
return err
}
g.setLastError(nil)
// prevent to work with a deprecated Relay client instance
g.drainRelayClientChan()
@@ -125,6 +146,7 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool {
}
func (g *Guard) notifyReconnected() {
g.setLastError(nil)
select {
case g.OnReconnected <- struct{}{}:
default:

View File

@@ -79,23 +79,30 @@ type Manager struct {
cleanupInterval time.Duration
keepUnusedServerTime time.Duration
// transportFallback is shared across home and foreign relay clients so a
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
transportFallback *transportFallback
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
tokenStore := &relayAuth.TokenStore{}
tf := newTransportFallback()
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
mtu: mtu,
transportFallback: tf,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
TransportFallback: tf,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
@@ -123,6 +130,9 @@ func (m *Manager) Serve() error {
client, err := m.serverPicker.PickServer(m.ctx)
if err != nil {
// record the initial failure so status shows the real reason before
// the guard's first retry tick
m.reconnectGuard.setLastError(err)
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
} else {
m.storeClient(client)
@@ -235,6 +245,67 @@ func (m *Manager) ServerURLs() []string {
return m.serverPicker.ServerURLs.Load().([]string)
}
// RelayConnectError returns the error from the most recent failed home relay
// reconnect attempt, or nil if the relay last connected successfully.
func (m *Manager) RelayConnectError() error {
return m.reconnectGuard.LastError()
}
// RelayConnState is the connection state of a single relay server.
type RelayConnState struct {
// URL is the server's instance address when connected, otherwise the
// configured server URL.
URL string
// Transport is the negotiated transport, empty if not connected.
Transport string
// Err is set when the relay is not connected.
Err error
}
// RelayStates returns the connection state of the home relay and every foreign
// relay the manager currently tracks.
func (m *Manager) RelayStates() []RelayConnState {
var states []RelayConnState
m.relayClientMu.RLock()
home := m.relayClient
m.relayClientMu.RUnlock()
if home != nil {
st := relayConnState(home)
// The home relay reconnects through the guard, so the real failure
// reason lives there rather than on the (stale) client.
if st.Err != nil {
if gErr := m.reconnectGuard.LastError(); gErr != nil {
st.Err = gErr
}
}
states = append(states, st)
}
// Snapshot the tracks, then query each outside the map lock: a track can be
// held by an in-progress Connect, and blocking on it must not stall other
// relay operations.
m.relayClientsMutex.RLock()
tracks := make([]*RelayTrack, 0, len(m.relayClients))
for _, rt := range m.relayClients {
tracks = append(tracks, rt)
}
m.relayClientsMutex.RUnlock()
// Only connected foreign relays carry state; a failed connect is evicted
// immediately (openConnVia), so there is no error state to surface.
for _, rt := range tracks {
rt.RLock()
rc := rt.relayClient
rt.RUnlock()
if rc != nil {
states = append(states, relayConnState(rc))
}
}
return states
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
@@ -287,6 +358,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
m.relayClientsMutex.Unlock()
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
relayClient.SetTransportFallback(m.transportFallback)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err
@@ -452,3 +524,11 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
}
delete(m.onDisconnectedListeners, serverAddress)
}
func relayConnState(c *Client) RelayConnState {
addr, err := c.ServerInstanceURL()
if err != nil {
return RelayConnState{URL: c.connectionURL, Err: err}
}
return RelayConnState{URL: addr, Transport: c.Transport()}
}

View File

@@ -29,6 +29,7 @@ type ServerPicker struct {
PeerID string
MTU uint16
ConnectionTimeout time.Duration
TransportFallback *transportFallback
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
@@ -39,6 +40,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
errChan := make(chan error, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
@@ -53,23 +55,24 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
}(url)
}
go sp.processConnResults(connResultChan, successChan)
go sp.processConnResults(connResultChan, successChan, errChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
return nil, <-errChan
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
return nil, fmt.Errorf("connect to relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
relayClient.SetTransportFallback(sp.TransportFallback)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,
@@ -78,12 +81,14 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult, errChan chan error) {
var hasSuccess bool
var errs []error
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
errs = append(errs, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
@@ -99,5 +104,16 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
hasSuccess = true
successChan <- cr
}
if !hasSuccess {
errChan <- pickErr(errs)
}
close(successChan)
}
// pickErr combines per-server connection failures into a single error.
func pickErr(errs []error) error {
if len(errs) == 0 {
return errors.New("no relay server available")
}
return errors.Join(errs...)
}

View File

@@ -0,0 +1,129 @@
package client
import (
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
)
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
// the other only if it fails to connect; no race). The prefer modes trade a
// slower connect when the preferred transport is blackholed for deterministic
// transport selection.
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
const (
// transportFallbackBase is the initial window a relay server avoids
// datagram-sized transports after a datagram is rejected as too large.
transportFallbackBase = 10 * time.Minute
// transportFallbackMax caps the pinned window when failures repeat.
transportFallbackMax = 60 * time.Minute
)
// TransportMode selects which relay dialers are used.
type TransportMode string
const (
TransportModeAuto TransportMode = "auto"
TransportModeQUIC TransportMode = "quic"
TransportModeWS TransportMode = "ws"
TransportModePreferQUIC TransportMode = "prefer-quic"
TransportModePreferWS TransportMode = "prefer-ws"
)
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
// or unrecognized value.
func transportModeFromEnv() TransportMode {
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
case "", TransportModeAuto:
return TransportModeAuto
case TransportModeQUIC:
return TransportModeQUIC
case TransportModeWS:
return TransportModeWS
case TransportModePreferQUIC:
return TransportModePreferQUIC
case TransportModePreferWS:
return TransportModePreferWS
default:
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
return TransportModeAuto
}
}
// sequential reports whether the mode tries dialers in order with fallback
// instead of racing them concurrently.
func (m TransportMode) sequential() bool {
return m == TransportModePreferQUIC || m == TransportModePreferWS
}
// transportFallback tracks relay servers that have rejected a datagram-sized
// transport (a write too large for the path) and should temporarily avoid such
// transports. It is shared across the relay manager so the preference survives
// client recreation (foreign relay clients are evicted and rebuilt on
// disconnect). Entries are keyed by server URL and expire after a window that
// grows on repeated failures.
type transportFallback struct {
mu sync.Mutex
entries map[string]*fallbackEntry
}
type fallbackEntry struct {
until time.Time
duration time.Duration
}
func newTransportFallback() *transportFallback {
return &transportFallback{entries: make(map[string]*fallbackEntry)}
}
// avoidDatagramSized reports whether serverURL is currently within a window
// where datagram-sized transports should be avoided.
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
f.mu.Lock()
defer f.mu.Unlock()
e := f.entries[serverURL]
return e != nil && time.Now().Before(e.until)
}
// recordFailure makes serverURL avoid datagram-sized transports for a window:
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
// when a datagram transport fails again after a previous window expired. It
// returns the active window duration.
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
f.mu.Lock()
defer f.mu.Unlock()
now := time.Now()
e := f.entries[serverURL]
switch {
case e == nil:
e = &fallbackEntry{duration: transportFallbackBase}
f.entries[serverURL] = e
case now.Before(e.until):
return time.Until(e.until)
default:
e.duration = min(e.duration*2, transportFallbackMax)
}
e.until = now.Add(e.duration)
return e.duration
}
// nonDatagramSized returns the dialers from in that are not datagram-sized,
// preserving order.
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
out := make([]dialer.DialeFn, 0, len(in))
for _, d := range in {
if !dialer.IsDatagramSized(d) {
out = append(out, d)
}
}
return out
}

View File

@@ -0,0 +1,140 @@
package client
import (
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
// closeTrackingConn records whether Close was called; only Close is exercised.
type closeTrackingConn struct {
net.Conn
closed bool
}
func (c *closeTrackingConn) Close() error {
c.closed = true
return nil
}
func TestTransportModeFromEnv(t *testing.T) {
tests := []struct {
value string
want TransportMode
}{
{"", TransportModeAuto},
{"auto", TransportModeAuto},
{"quic", TransportModeQUIC},
{"QUIC", TransportModeQUIC},
{"ws", TransportModeWS},
{" Ws ", TransportModeWS},
{"prefer-quic", TransportModePreferQUIC},
{"prefer-ws", TransportModePreferWS},
{"garbage", TransportModeAuto},
}
for _, tc := range tests {
t.Run(tc.value, func(t *testing.T) {
t.Setenv(EnvRelayTransport, tc.value)
if tc.value == "" {
os.Unsetenv(EnvRelayTransport)
}
assert.Equal(t, tc.want, transportModeFromEnv())
})
}
}
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
d := f.recordFailure(url)
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
// A second failure while still inside the window must not grow the window.
d = f.recordFailure(url)
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
require.NotNil(t, f.entries[url])
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
// Expire the window: datagram-sized transport allowed again.
f.entries[url].until = time.Now().Add(-time.Second)
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
}
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
const url = "rels://relay.example:443"
f := newTransportFallback()
want := transportFallbackBase
for i := range 6 {
d := f.recordFailure(url)
assert.Equal(t, want, d, "window after %d expiries", i)
// expire the window so the next failure is treated as a repeat
f.entries[url].until = time.Now().Add(-time.Second)
want = min(want*2, transportFallbackMax)
}
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
}
func TestOnDatagramTooLargeAuto(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.True(t, conn.closed, "connection closed to force reconnect")
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
// A second oversized datagram on the same connection must not re-close.
conn.closed = false
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "single fallback per connection")
}
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
const url = "rels://relay.example:443"
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
tf := newTransportFallback()
c := &Client{
log: log.WithField("test", t.Name()),
connectionURL: url,
transportFallback: tf,
}
conn := &closeTrackingConn{}
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
}
func TestTransportFallbackPerServer(t *testing.T) {
f := newTransportFallback()
f.recordFailure("rels://a.example:443")
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
}