mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-05 15:39:56 +00:00
Compare commits
10 Commits
fix/ios-de
...
relay-tran
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2b4f2c0bf | ||
|
|
d50241b75d | ||
|
|
3bad4c939b | ||
|
|
ac70eb2c8e | ||
|
|
36f346c0fe | ||
|
|
fbd97d6da5 | ||
|
|
3435267a23 | ||
|
|
eac6d501c3 | ||
|
|
deeae30612 | ||
|
|
f3cdf163e1 |
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -45,4 +45,11 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
61
.github/workflows/golang-test-linux.yml
vendored
61
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,16 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,client
|
||||
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
@@ -276,9 +285,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test ${{ matrix.raceFlag }} \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,relay
|
||||
|
||||
test_proxy:
|
||||
name: "Proxy / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -326,7 +343,15 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test -timeout 10m -p 1 ./proxy/...
|
||||
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,proxy
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
@@ -377,9 +402,17 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-exec 'sudo' -coverprofile=coverage.txt \
|
||||
-timeout 10m ./signal/... ./shared/signal/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,signal
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -445,10 +478,18 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
go test -tags=devcert -coverprofile=coverage.txt \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/... ./shared/management/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: unit,management
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [build-cache]
|
||||
@@ -687,6 +728,14 @@ jobs:
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
go test -tags=integration -coverprofile=coverage.txt \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/server/http/...
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
slug: netbirdio/netbird
|
||||
flags: integration,management
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -1015,14 +1015,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return d.relayStates
|
||||
}
|
||||
|
||||
// extend the list of stun, turn servers with relay address
|
||||
// extend the list of stun, turn servers with the relay server connections
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
states := d.relayMgr.RelayStates()
|
||||
if len(states) == 0 {
|
||||
// no relay connection tracked yet; surface configured servers as
|
||||
// unavailable with the real reconnect error when known
|
||||
err := relayClient.ErrRelayClientNotConnected
|
||||
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
|
||||
err = connErr
|
||||
}
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
@@ -1032,10 +1035,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
return relayStates
|
||||
}
|
||||
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
for _, rs := range states {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: rs.URL,
|
||||
Err: rs.Err,
|
||||
Transport: rs.Transport,
|
||||
})
|
||||
}
|
||||
return append(relayStates, relayState)
|
||||
return relayStates
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
@@ -1395,6 +1402,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
|
||||
pbRelayState := &proto.RelayState{
|
||||
URI: relayState.URI,
|
||||
Available: relayState.Err == nil,
|
||||
Transport: relayState.Transport,
|
||||
}
|
||||
if err := relayState.Err; err != nil {
|
||||
pbRelayState.Error = err.Error()
|
||||
|
||||
@@ -32,6 +32,9 @@ type ProbeResult struct {
|
||||
URI string
|
||||
Err error
|
||||
Addr string
|
||||
// Transport is the negotiated relay transport, empty
|
||||
// for stun/turn probes or when not connected.
|
||||
Transport string
|
||||
}
|
||||
|
||||
type StunTurnProbe struct {
|
||||
|
||||
@@ -1832,6 +1832,7 @@ type RelayState struct {
|
||||
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
|
||||
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
|
||||
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
|
||||
Transport string `protobuf:"bytes,4,opt,name=transport,proto3" json:"transport,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1887,6 +1888,13 @@ func (x *RelayState) GetError() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RelayState) GetTransport() string {
|
||||
if x != nil {
|
||||
return x.Transport
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type NSGroupState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
|
||||
@@ -6406,12 +6414,13 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0fManagementState\x12\x10\n" +
|
||||
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
|
||||
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"R\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"p\n" +
|
||||
"\n" +
|
||||
"RelayState\x12\x10\n" +
|
||||
"\x03URI\x18\x01 \x01(\tR\x03URI\x12\x1c\n" +
|
||||
"\tavailable\x18\x02 \x01(\bR\tavailable\x12\x14\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\"r\n" +
|
||||
"\x05error\x18\x03 \x01(\tR\x05error\x12\x1c\n" +
|
||||
"\ttransport\x18\x04 \x01(\tR\ttransport\"r\n" +
|
||||
"\fNSGroupState\x12\x18\n" +
|
||||
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
|
||||
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
|
||||
|
||||
@@ -370,6 +370,7 @@ message RelayState {
|
||||
string URI = 1;
|
||||
bool available = 2;
|
||||
string error = 3;
|
||||
string transport = 4;
|
||||
}
|
||||
|
||||
message NSGroupState {
|
||||
|
||||
@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
|
||||
}
|
||||
|
||||
type RelayStateOutput struct {
|
||||
@@ -217,7 +218,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
RelayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
Error: relayErrorString(relay.GetError()),
|
||||
Transport: relay.GetTransport(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -233,6 +235,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
// relayErrorString flattens a newline-joined aggregated relay error onto a
|
||||
// single line for status output.
|
||||
func relayErrorString(s string) string {
|
||||
return strings.ReplaceAll(s, "\n", "; ")
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
@@ -439,6 +447,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
} else if relay.Transport != "" {
|
||||
available = fmt.Sprintf("%s via %s", available, relay.Transport)
|
||||
}
|
||||
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
|
||||
@@ -641,3 +641,13 @@ func TestTimeAgo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRelaysTransport(t *testing.T) {
|
||||
out := mapRelays([]*proto.RelayState{
|
||||
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
|
||||
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
|
||||
})
|
||||
require.Len(t, out.Details, 2)
|
||||
assert.Equal(t, "quic", out.Details[0].Transport)
|
||||
assert.Equal(t, "ws", out.Details[1].Transport)
|
||||
}
|
||||
|
||||
@@ -311,11 +311,12 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
|
||||
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
NETBIRD_PROXY_IMAGE=${NETBIRD_PROXY_IMAGE:-"netbirdio/reverse-proxy:latest"}
|
||||
TRAEFIK_IMAGE=${TRAEFIK_IMAGE:-"traefik:v3.6"}
|
||||
CROWDSEC_IMAGE=${CROWDSEC_IMAGE:-"crowdsecurity/crowdsec:v1.7.7"}
|
||||
# Reverse proxy configuration
|
||||
REVERSE_PROXY_TYPE="0"
|
||||
TRAEFIK_EXTERNAL_NETWORK=""
|
||||
@@ -656,7 +657,7 @@ render_docker_compose_traefik_builtin() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
crowdsec_service="
|
||||
crowdsec:
|
||||
image: crowdsecurity/crowdsec:v1.7.7
|
||||
image: $CROWDSEC_IMAGE
|
||||
container_name: netbird-crowdsec
|
||||
restart: unless-stopped
|
||||
networks: [netbird]
|
||||
@@ -687,7 +688,7 @@ render_docker_compose_traefik_builtin() {
|
||||
services:
|
||||
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
|
||||
traefik:
|
||||
image: traefik:v3.6
|
||||
image: $TRAEFIK_IMAGE
|
||||
container_name: netbird-traefik
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
@@ -771,7 +772,7 @@ $traefik_dynamic_volume
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
# gRPC router (needs h2c backend for HTTP/2 cleartext)
|
||||
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`))
|
||||
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`) || PathPrefix(\`/management.ProxyService/\`))
|
||||
- traefik.http.routers.netbird-grpc.entrypoints=websecure
|
||||
- traefik.http.routers.netbird-grpc.tls=true
|
||||
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt
|
||||
|
||||
@@ -122,7 +122,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
s.errCh = make(chan error, 4)
|
||||
|
||||
if s.autoResolveDomains {
|
||||
s.resolveDomains(srvCtx)
|
||||
s.ResolveDomains(srvCtx)
|
||||
}
|
||||
|
||||
s.PeersManager()
|
||||
@@ -398,10 +398,10 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
|
||||
}()
|
||||
}
|
||||
|
||||
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||
// ResolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
|
||||
// Fresh installs use the default self-hosted domain, while existing installs reuse the
|
||||
// persisted account domain to keep addressing stable across config changes.
|
||||
func (s *BaseServer) resolveDomains(ctx context.Context) {
|
||||
func (s *BaseServer) ResolveDomains(ctx context.Context) {
|
||||
st := s.Store()
|
||||
|
||||
setDefault := func(logMsg string, args ...any) {
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
@@ -40,7 +40,7 @@ func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
|
||||
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
|
||||
@@ -56,7 +56,7 @@ func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
|
||||
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
|
||||
Inject[store.Store](srv, mockStore)
|
||||
|
||||
srv.resolveDomains(context.Background())
|
||||
srv.ResolveDomains(context.Background())
|
||||
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
|
||||
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
@@ -119,8 +120,8 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
|
||||
}
|
||||
|
||||
// PeerConnected increments the number of connected peers and increments number of idle connections
|
||||
func (m *Metrics) PeerConnected(id string) {
|
||||
m.peers.Add(m.ctx, 1)
|
||||
func (m *Metrics) PeerConnected(id, transport string) {
|
||||
m.peers.Add(m.ctx, 1, metric.WithAttributes(attribute.String("transport", transport)))
|
||||
m.mutexActivity.Lock()
|
||||
defer m.mutexActivity.Unlock()
|
||||
|
||||
@@ -138,8 +139,8 @@ func (m *Metrics) RecordPeerStoreTime(duration time.Duration) {
|
||||
}
|
||||
|
||||
// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
|
||||
func (m *Metrics) PeerDisconnected(id string) {
|
||||
m.peers.Add(m.ctx, -1)
|
||||
func (m *Metrics) PeerDisconnected(id, transport string) {
|
||||
m.peers.Add(m.ctx, -1, metric.WithAttributes(attribute.String("transport", transport)))
|
||||
m.mutexActivity.Lock()
|
||||
defer m.mutexActivity.Unlock()
|
||||
|
||||
|
||||
@@ -11,4 +11,6 @@ type Conn interface {
|
||||
Write(ctx context.Context, b []byte) (n int, err error)
|
||||
RemoteAddr() net.Addr
|
||||
Close() error
|
||||
// Protocol returns the transport name.
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
@@ -42,6 +42,11 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
// Protocol returns the transport name for this connection.
|
||||
func (c *Conn) Protocol() string {
|
||||
return "quic"
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
if c.closed {
|
||||
|
||||
@@ -64,6 +64,11 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.rAddr
|
||||
}
|
||||
|
||||
// Protocol returns the transport name for this connection.
|
||||
func (c *Conn) Protocol() string {
|
||||
return "ws"
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
|
||||
@@ -154,15 +154,16 @@ func (r *Relay) Accept(conn listener.Conn) {
|
||||
}
|
||||
r.notifier.PeerCameOnline(peer.ID())
|
||||
|
||||
transport := conn.Protocol()
|
||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||
r.metrics.PeerConnected(peer.String())
|
||||
r.metrics.PeerConnected(peer.String(), transport)
|
||||
go func() {
|
||||
peer.Work()
|
||||
if deleted := r.store.DeletePeer(peer); deleted {
|
||||
r.notifier.PeerWentOffline(peer.ID())
|
||||
}
|
||||
peer.log.Debugf("relay connection closed")
|
||||
r.metrics.PeerDisconnected(peer.String())
|
||||
r.metrics.PeerDisconnected(peer.String(), transport)
|
||||
}()
|
||||
|
||||
if err := h.handshakeResponse(hsCtx); err != nil {
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -143,6 +145,11 @@ func (cc *connContainer) close() {
|
||||
}
|
||||
}
|
||||
|
||||
// transportConn is implemented by relay connections that know their transport.
|
||||
type transportConn interface {
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
||||
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
||||
@@ -172,6 +179,31 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
|
||||
// transport is the negotiated relay transport of the
|
||||
// current connection, guarded by mu.
|
||||
transport string
|
||||
}
|
||||
|
||||
// Transport returns the negotiated relay transport of the current connection,
|
||||
// or an empty string when not connected.
|
||||
func (c *Client) Transport() string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.transport
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -361,12 +393,13 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
dialers := c.getDialers()
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -375,6 +408,9 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -382,6 +418,10 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
if tc, ok := conn.(transportConn); ok {
|
||||
c.transport = tc.Protocol()
|
||||
}
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -396,7 +436,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -406,6 +446,9 @@ func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -631,13 +674,53 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
@@ -729,6 +812,7 @@ func (c *Client) close(gracefullyExit bool) error {
|
||||
return nil
|
||||
}
|
||||
c.serviceIsRunning = false
|
||||
c.transport = ""
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = nil
|
||||
|
||||
18
shared/relay/client/dialer/capability.go
Normal file
18
shared/relay/client/dialer/capability.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is implemented by dialers whose connections carry each write in
|
||||
// a single datagram, so a write can be rejected when it exceeds the path's
|
||||
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
|
||||
// WebSocket over TCP) impose no per-write size limit, so the relay client can
|
||||
// fall back to them when a datagram-sized transport rejects a write as too
|
||||
// large. The capability is advertised per dialer rather than hardcoded, so a
|
||||
// new transport only needs to declare whether it is datagram-sized.
|
||||
type DatagramSized interface {
|
||||
DatagramSized()
|
||||
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,4 +4,9 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -52,15 +51,17 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Protocol returns the transport name for this connection.
|
||||
func (c *Conn) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
@@ -95,3 +96,15 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -23,6 +23,12 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -47,26 +53,21 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on UDP: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("listen udp: %w", err)
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve UDP address: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("resolve %s: %w", quicURL, err)
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -74,6 +75,28 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -3,6 +3,7 @@ package dialer
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,7 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -53,9 +55,24 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
errChan := make(chan error, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
defer abort()
|
||||
|
||||
@@ -63,15 +80,41 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
go r.dial(dfn, abortCtx, connChan)
|
||||
}
|
||||
|
||||
go r.processResults(connChan, winnerConn, abort)
|
||||
go r.processResults(connChan, winnerConn, errChan, abort)
|
||||
|
||||
conn, ok := <-winnerConn
|
||||
if !ok {
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
return nil, <-errChan
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
var errs []error
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
errs = append(errs, fmt.Errorf("%s: %w", dfn.Protocol(), err))
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, dialErr(errs)
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
@@ -81,8 +124,9 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
|
||||
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, errChan chan error, abort context.CancelFunc) {
|
||||
var hasWinner bool
|
||||
errsByProtocol := make(map[string]error)
|
||||
for i := 0; i < len(r.dialerFns); i++ {
|
||||
dr := <-connChan
|
||||
if dr.Err != nil {
|
||||
@@ -90,6 +134,7 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
|
||||
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
|
||||
} else {
|
||||
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
|
||||
errsByProtocol[dr.Protocol] = fmt.Errorf("%s: %w", dr.Protocol, dr.Err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -107,5 +152,29 @@ func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.
|
||||
hasWinner = true
|
||||
winnerConn <- dr.Conn
|
||||
}
|
||||
if !hasWinner {
|
||||
errChan <- dialErr(r.orderedErrs(errsByProtocol))
|
||||
}
|
||||
close(winnerConn)
|
||||
}
|
||||
|
||||
// orderedErrs returns the per-protocol errors in dialer order, so the combined
|
||||
// error is stable regardless of which attempt failed first.
|
||||
func (r *RaceDial) orderedErrs(byProtocol map[string]error) []error {
|
||||
errs := make([]error, 0, len(byProtocol))
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err, ok := byProtocol[dfn.Protocol()]; ok {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// dialErr combines per-dialer failures, preserving the underlying reasons
|
||||
// (e.g. "connection refused") rather than a generic message.
|
||||
func dialErr(errs []error) error {
|
||||
if len(errs) == 0 {
|
||||
return errors.New("no relay transport available")
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -250,3 +250,66 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,11 @@ func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Protocol returns the transport name for this connection.
|
||||
func (c *Conn) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,7 +22,7 @@ type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
@@ -39,7 +39,12 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
// websocket.Dial wraps the cause in verbose layers; surface the
|
||||
// underlying network error when present.
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
return nil, opErr
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
|
||||
@@ -9,11 +9,42 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
101
shared/relay/client/dialers_generic_test.go
Normal file
101
shared/relay/client/dialers_generic_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"ws"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "ws"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"ws", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"ws"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"ws"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "ws"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"ws"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -20,6 +21,10 @@ type Guard struct {
|
||||
// maxBackoffInterval caps the exponential backoff between reconnect
|
||||
// attempts.
|
||||
maxBackoffInterval time.Duration
|
||||
|
||||
// lastErr is the error from the most recent failed reconnect attempt,
|
||||
// surfaced as the home relay status while disconnected.
|
||||
lastErr atomic.Pointer[error]
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client. A non-positive
|
||||
@@ -37,6 +42,19 @@ func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard {
|
||||
return g
|
||||
}
|
||||
|
||||
// LastError returns the error from the most recent failed reconnect attempt, or
|
||||
// nil if reconnection last succeeded.
|
||||
func (g *Guard) LastError() error {
|
||||
if p := g.lastErr.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) setLastError(err error) {
|
||||
g.lastErr.Store(&err)
|
||||
}
|
||||
|
||||
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
|
||||
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
|
||||
// to the same server that was used before, if the server URL is still valid. If the quick
|
||||
@@ -63,6 +81,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
||||
case <-ticker.C:
|
||||
if err := g.retry(ctx); err != nil {
|
||||
log.Errorf("failed to pick new Relay server: %s", err)
|
||||
g.setLastError(err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
@@ -89,6 +108,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
|
||||
|
||||
if err := rc.Connect(parentCtx); err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
g.setLastError(err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -100,6 +120,7 @@ func (g *Guard) retry(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.setLastError(nil)
|
||||
|
||||
// prevent to work with a deprecated Relay client instance
|
||||
g.drainRelayClientChan()
|
||||
@@ -125,6 +146,7 @@ func (g *Guard) isServerURLStillValid(rc *Client) bool {
|
||||
}
|
||||
|
||||
func (g *Guard) notifyReconnected() {
|
||||
g.setLastError(nil)
|
||||
select {
|
||||
case g.OnReconnected <- struct{}{}:
|
||||
default:
|
||||
|
||||
@@ -79,23 +79,30 @@ type Manager struct {
|
||||
|
||||
cleanupInterval time.Duration
|
||||
keepUnusedServerTime time.Duration
|
||||
|
||||
// transportFallback is shared across home and foreign relay clients so a
|
||||
// datagram-too-large failure makes that server avoid datagram-sized transports across reconnects.
|
||||
transportFallback *transportFallback
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
tf := newTransportFallback()
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
mtu: mtu,
|
||||
transportFallback: tf,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
MTU: mtu,
|
||||
ConnectionTimeout: defaultConnectionTimeout,
|
||||
TransportFallback: tf,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
@@ -123,6 +130,9 @@ func (m *Manager) Serve() error {
|
||||
|
||||
client, err := m.serverPicker.PickServer(m.ctx)
|
||||
if err != nil {
|
||||
// record the initial failure so status shows the real reason before
|
||||
// the guard's first retry tick
|
||||
m.reconnectGuard.setLastError(err)
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
|
||||
} else {
|
||||
m.storeClient(client)
|
||||
@@ -235,6 +245,67 @@ func (m *Manager) ServerURLs() []string {
|
||||
return m.serverPicker.ServerURLs.Load().([]string)
|
||||
}
|
||||
|
||||
// RelayConnectError returns the error from the most recent failed home relay
|
||||
// reconnect attempt, or nil if the relay last connected successfully.
|
||||
func (m *Manager) RelayConnectError() error {
|
||||
return m.reconnectGuard.LastError()
|
||||
}
|
||||
|
||||
// RelayConnState is the connection state of a single relay server.
|
||||
type RelayConnState struct {
|
||||
// URL is the server's instance address when connected, otherwise the
|
||||
// configured server URL.
|
||||
URL string
|
||||
// Transport is the negotiated transport, empty if not connected.
|
||||
Transport string
|
||||
// Err is set when the relay is not connected.
|
||||
Err error
|
||||
}
|
||||
|
||||
// RelayStates returns the connection state of the home relay and every foreign
|
||||
// relay the manager currently tracks.
|
||||
func (m *Manager) RelayStates() []RelayConnState {
|
||||
var states []RelayConnState
|
||||
|
||||
m.relayClientMu.RLock()
|
||||
home := m.relayClient
|
||||
m.relayClientMu.RUnlock()
|
||||
if home != nil {
|
||||
st := relayConnState(home)
|
||||
// The home relay reconnects through the guard, so the real failure
|
||||
// reason lives there rather than on the (stale) client.
|
||||
if st.Err != nil {
|
||||
if gErr := m.reconnectGuard.LastError(); gErr != nil {
|
||||
st.Err = gErr
|
||||
}
|
||||
}
|
||||
states = append(states, st)
|
||||
}
|
||||
|
||||
// Snapshot the tracks, then query each outside the map lock: a track can be
|
||||
// held by an in-progress Connect, and blocking on it must not stall other
|
||||
// relay operations.
|
||||
m.relayClientsMutex.RLock()
|
||||
tracks := make([]*RelayTrack, 0, len(m.relayClients))
|
||||
for _, rt := range m.relayClients {
|
||||
tracks = append(tracks, rt)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// Only connected foreign relays carry state; a failed connect is evicted
|
||||
// immediately (openConnVia), so there is no error state to surface.
|
||||
for _, rt := range tracks {
|
||||
rt.RLock()
|
||||
rc := rt.relayClient
|
||||
rt.RUnlock()
|
||||
if rc != nil {
|
||||
states = append(states, relayConnState(rc))
|
||||
}
|
||||
}
|
||||
|
||||
return states
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
@@ -287,6 +358,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClientWithServerIP(serverAddress, serverIP, m.tokenStore, m.peerID, m.mtu)
|
||||
relayClient.SetTransportFallback(m.transportFallback)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
@@ -452,3 +524,11 @@ func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
}
|
||||
delete(m.onDisconnectedListeners, serverAddress)
|
||||
}
|
||||
|
||||
func relayConnState(c *Client) RelayConnState {
|
||||
addr, err := c.ServerInstanceURL()
|
||||
if err != nil {
|
||||
return RelayConnState{URL: c.connectionURL, Err: err}
|
||||
}
|
||||
return RelayConnState{URL: addr, Transport: c.Transport()}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type ServerPicker struct {
|
||||
PeerID string
|
||||
MTU uint16
|
||||
ConnectionTimeout time.Duration
|
||||
TransportFallback *transportFallback
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
@@ -39,6 +40,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
errChan := make(chan error, 1)
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
|
||||
@@ -53,23 +55,24 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
}(url)
|
||||
}
|
||||
|
||||
go sp.processConnResults(connResultChan, successChan)
|
||||
go sp.processConnResults(connResultChan, successChan, errChan)
|
||||
|
||||
select {
|
||||
case cr, ok := <-successChan:
|
||||
if !ok {
|
||||
return nil, errors.New("failed to connect to any relay server: all attempts failed")
|
||||
return nil, <-errChan
|
||||
}
|
||||
log.Infof("chosen home Relay server: %s", cr.Url)
|
||||
return cr.RelayClient, nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
return nil, fmt.Errorf("connect to relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU)
|
||||
relayClient.SetTransportFallback(sp.TransportFallback)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
@@ -78,12 +81,14 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult, errChan chan error) {
|
||||
var hasSuccess bool
|
||||
var errs []error
|
||||
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||
cr := <-resultChan
|
||||
if cr.Err != nil {
|
||||
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
errs = append(errs, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
@@ -99,5 +104,16 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
|
||||
hasSuccess = true
|
||||
successChan <- cr
|
||||
}
|
||||
if !hasSuccess {
|
||||
errChan <- pickErr(errs)
|
||||
}
|
||||
close(successChan)
|
||||
}
|
||||
|
||||
// pickErr combines per-server connection failures into a single error.
|
||||
func pickErr(errs []error) error {
|
||||
if len(errs) == 0 {
|
||||
return errors.New("no relay server available")
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
129
shared/relay/client/transport.go
Normal file
129
shared/relay/client/transport.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
)
|
||||
|
||||
// EnvRelayTransport pins the relay transport. Valid values: "auto" (default,
|
||||
// race QUIC and WebSocket), "quic" (QUIC only), "ws" (WebSocket only),
|
||||
// "prefer-quic" / "prefer-ws" (try the preferred transport first, fall back to
|
||||
// the other only if it fails to connect; no race). The prefer modes trade a
|
||||
// slower connect when the preferred transport is blackholed for deterministic
|
||||
// transport selection.
|
||||
const EnvRelayTransport = "NB_RELAY_TRANSPORT"
|
||||
|
||||
const (
|
||||
// transportFallbackBase is the initial window a relay server avoids
|
||||
// datagram-sized transports after a datagram is rejected as too large.
|
||||
transportFallbackBase = 10 * time.Minute
|
||||
// transportFallbackMax caps the pinned window when failures repeat.
|
||||
transportFallbackMax = 60 * time.Minute
|
||||
)
|
||||
|
||||
// TransportMode selects which relay dialers are used.
|
||||
type TransportMode string
|
||||
|
||||
const (
|
||||
TransportModeAuto TransportMode = "auto"
|
||||
TransportModeQUIC TransportMode = "quic"
|
||||
TransportModeWS TransportMode = "ws"
|
||||
TransportModePreferQUIC TransportMode = "prefer-quic"
|
||||
TransportModePreferWS TransportMode = "prefer-ws"
|
||||
)
|
||||
|
||||
// transportModeFromEnv reads EnvRelayTransport, defaulting to auto for an empty
|
||||
// or unrecognized value.
|
||||
func transportModeFromEnv() TransportMode {
|
||||
switch TransportMode(strings.ToLower(strings.TrimSpace(os.Getenv(EnvRelayTransport)))) {
|
||||
case "", TransportModeAuto:
|
||||
return TransportModeAuto
|
||||
case TransportModeQUIC:
|
||||
return TransportModeQUIC
|
||||
case TransportModeWS:
|
||||
return TransportModeWS
|
||||
case TransportModePreferQUIC:
|
||||
return TransportModePreferQUIC
|
||||
case TransportModePreferWS:
|
||||
return TransportModePreferWS
|
||||
default:
|
||||
log.Warnf("invalid %s value %q, using %q", EnvRelayTransport, os.Getenv(EnvRelayTransport), TransportModeAuto)
|
||||
return TransportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// sequential reports whether the mode tries dialers in order with fallback
|
||||
// instead of racing them concurrently.
|
||||
func (m TransportMode) sequential() bool {
|
||||
return m == TransportModePreferQUIC || m == TransportModePreferWS
|
||||
}
|
||||
|
||||
// transportFallback tracks relay servers that have rejected a datagram-sized
|
||||
// transport (a write too large for the path) and should temporarily avoid such
|
||||
// transports. It is shared across the relay manager so the preference survives
|
||||
// client recreation (foreign relay clients are evicted and rebuilt on
|
||||
// disconnect). Entries are keyed by server URL and expire after a window that
|
||||
// grows on repeated failures.
|
||||
type transportFallback struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*fallbackEntry
|
||||
}
|
||||
|
||||
type fallbackEntry struct {
|
||||
until time.Time
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func newTransportFallback() *transportFallback {
|
||||
return &transportFallback{entries: make(map[string]*fallbackEntry)}
|
||||
}
|
||||
|
||||
// avoidDatagramSized reports whether serverURL is currently within a window
|
||||
// where datagram-sized transports should be avoided.
|
||||
func (f *transportFallback) avoidDatagramSized(serverURL string) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
e := f.entries[serverURL]
|
||||
return e != nil && time.Now().Before(e.until)
|
||||
}
|
||||
|
||||
// recordFailure makes serverURL avoid datagram-sized transports for a window:
|
||||
// transportFallbackBase on the first failure, doubling up to transportFallbackMax
|
||||
// when a datagram transport fails again after a previous window expired. It
|
||||
// returns the active window duration.
|
||||
func (f *transportFallback) recordFailure(serverURL string) time.Duration {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
e := f.entries[serverURL]
|
||||
switch {
|
||||
case e == nil:
|
||||
e = &fallbackEntry{duration: transportFallbackBase}
|
||||
f.entries[serverURL] = e
|
||||
case now.Before(e.until):
|
||||
return time.Until(e.until)
|
||||
default:
|
||||
e.duration = min(e.duration*2, transportFallbackMax)
|
||||
}
|
||||
e.until = now.Add(e.duration)
|
||||
return e.duration
|
||||
}
|
||||
|
||||
// nonDatagramSized returns the dialers from in that are not datagram-sized,
|
||||
// preserving order.
|
||||
func nonDatagramSized(in []dialer.DialeFn) []dialer.DialeFn {
|
||||
out := make([]dialer.DialeFn, 0, len(in))
|
||||
for _, d := range in {
|
||||
if !dialer.IsDatagramSized(d) {
|
||||
out = append(out, d)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
140
shared/relay/client/transport_test.go
Normal file
140
shared/relay/client/transport_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
// closeTrackingConn records whether Close was called; only Close is exercised.
|
||||
type closeTrackingConn struct {
|
||||
net.Conn
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *closeTrackingConn) Close() error {
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTransportModeFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
want TransportMode
|
||||
}{
|
||||
{"", TransportModeAuto},
|
||||
{"auto", TransportModeAuto},
|
||||
{"quic", TransportModeQUIC},
|
||||
{"QUIC", TransportModeQUIC},
|
||||
{"ws", TransportModeWS},
|
||||
{" Ws ", TransportModeWS},
|
||||
{"prefer-quic", TransportModePreferQUIC},
|
||||
{"prefer-ws", TransportModePreferWS},
|
||||
{"garbage", TransportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.value, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.value)
|
||||
if tc.value == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
assert.Equal(t, tc.want, transportModeFromEnv())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportFallbackRecordAndExpiry(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
assert.False(t, f.avoidDatagramSized(url), "no fallback recorded yet")
|
||||
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, transportFallbackBase, d, "first failure pins for the base window")
|
||||
assert.True(t, f.avoidDatagramSized(url), "datagram-sized transport avoided within the window")
|
||||
|
||||
// A second failure while still inside the window must not grow the window.
|
||||
d = f.recordFailure(url)
|
||||
assert.LessOrEqual(t, d, transportFallbackBase, "still within the active window")
|
||||
require.NotNil(t, f.entries[url])
|
||||
assert.Equal(t, transportFallbackBase, f.entries[url].duration, "duration unchanged inside window")
|
||||
|
||||
// Expire the window: datagram-sized transport allowed again.
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
assert.False(t, f.avoidDatagramSized(url), "window expired, datagram-sized transport allowed")
|
||||
}
|
||||
|
||||
func TestTransportFallbackGrowsOnRepeat(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
f := newTransportFallback()
|
||||
|
||||
want := transportFallbackBase
|
||||
for i := range 6 {
|
||||
d := f.recordFailure(url)
|
||||
assert.Equal(t, want, d, "window after %d expiries", i)
|
||||
|
||||
// expire the window so the next failure is treated as a repeat
|
||||
f.entries[url].until = time.Now().Add(-time.Second)
|
||||
|
||||
want = min(want*2, transportFallbackMax)
|
||||
}
|
||||
|
||||
assert.Equal(t, transportFallbackMax, f.entries[url].duration, "window caps at the max")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeAuto(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.True(t, conn.closed, "connection closed to force reconnect")
|
||||
assert.True(t, tf.avoidDatagramSized(url), "fallback recorded for the server")
|
||||
|
||||
// A second oversized datagram on the same connection must not re-close.
|
||||
conn.closed = false
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
assert.False(t, conn.closed, "single fallback per connection")
|
||||
}
|
||||
|
||||
func TestOnDatagramTooLargeQUICPinned(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeQUIC))
|
||||
|
||||
tf := newTransportFallback()
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
transportFallback: tf,
|
||||
}
|
||||
conn := &closeTrackingConn{}
|
||||
|
||||
c.onDatagramTooLarge(conn, netErr.ErrDatagramTooLarge)
|
||||
|
||||
assert.False(t, conn.closed, "QUIC pin keeps the connection, no fallback redial")
|
||||
assert.False(t, tf.avoidDatagramSized(url), "QUIC pin records no fallback")
|
||||
}
|
||||
|
||||
func TestTransportFallbackPerServer(t *testing.T) {
|
||||
f := newTransportFallback()
|
||||
f.recordFailure("rels://a.example:443")
|
||||
|
||||
assert.True(t, f.avoidDatagramSized("rels://a.example:443"))
|
||||
assert.False(t, f.avoidDatagramSized("rels://b.example:443"), "fallback is scoped to one server")
|
||||
}
|
||||
Reference in New Issue
Block a user