diff --git a/client/internal/engine.go b/client/internal/engine.go index b49e02c6d..09d80a87d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -570,7 +570,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) - e.srWatcher.Start() + e.srWatcher.Start(peer.IsForceRelayed()) e.receiveSignalEvents() e.receiveManagementEvents() diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8d1585b3f..1e416bfe7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) - relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) - if err != nil { - return err + forceRelay := IsForceRelayed() + if !forceRelay { + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) + if err != nil { + return err + } + conn.workerICE = workerICE } - conn.workerICE = workerICE conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if !forceRelay { conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } @@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgWatcherCancel() } conn.workerRelay.CloseConn() - conn.workerICE.Close() + if conn.workerICE != nil { + conn.workerICE.Close() + } if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() @@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) { // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { conn.dumpState.RemoteCandidate() - conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + if conn.workerICE != nil { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + } } // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established @@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusConnecting } -func (conn *Conn) isConnectedOnAllWay() (connected bool) { - // would be better to protect this with a mutex, but it could cause deadlock with Close function - +// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. +// +// The result is a tri-state: +// - ConnStatusConnected: all available transports are up +// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting +// - ConnStatusDisconnected: no working transport +func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { defer func() { - if !connected { + if status == guard.ConnStatusDisconnected { conn.logTraceConnState() } }() - // For JS platform: only relay connection is supported - if runtime.GOOS == "js" { - return conn.statusRelay.Get() == worker.StatusConnected + iceWorkerCreated := conn.workerICE != nil + + var iceInProgress bool + if iceWorkerCreated { + iceInProgress = conn.workerICE.InProgress() } - // For non-JS platforms: check ICE connection status - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { - return false - } - - // If relay is supported with peer, it must also be connected - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == worker.StatusDisconnected { - return false - } - } - - return true + return evalConnStatus(connStatusInputs{ + forceRelay: IsForceRelayed(), + peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), + relayConnected: conn.statusRelay.Get() == worker.StatusConnected, + remoteSupportsICE: conn.handshaker.RemoteICESupported(), + iceWorkerCreated: iceWorkerCreated, + iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, + iceInProgress: iceInProgress, + }) } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { @@ -926,3 +935,43 @@ func isController(config ConnConfig) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } + +func evalConnStatus(in connStatusInputs) guard.ConnStatus { + // "Relay up and needed" — the peer uses relay and the transport is connected. + relayUsedAndUp := in.peerUsesRelay && in.relayConnected + + // Force-relay mode: ICE never runs. Relay is the only transport and must be up. + if in.forceRelay { + return boolToConnStatus(relayUsedAndUp) + } + + // Remote peer doesn't support ICE, or we haven't created the worker yet: + // relay is the only possible transport. + if !in.remoteSupportsICE || !in.iceWorkerCreated { + return boolToConnStatus(relayUsedAndUp) + } + + // ICE counts as "up" when the status is anything other than Disconnected, OR + // when a negotiation is currently in progress (so we don't spam offers while one is in flight). + iceUp := in.iceStatusConnecting || in.iceInProgress + + // Relay side is acceptable if the peer doesn't rely on relay, or relay is connected. + relayOK := !in.peerUsesRelay || in.relayConnected + + switch { + case iceUp && relayOK: + return guard.ConnStatusConnected + case relayUsedAndUp: + // Relay is up but ICE is down — partially connected. + return guard.ConnStatusPartiallyConnected + default: + return guard.ConnStatusDisconnected + } +} + +func boolToConnStatus(connected bool) guard.ConnStatus { + if connected { + return guard.ConnStatusConnected + } + return guard.ConnStatusDisconnected +} diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 73acc5ef5..b43e245f3 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -13,6 +13,20 @@ const ( StatusConnected ) +// connStatusInputs is the primitive-valued snapshot of the state that drives the +// tri-state connection classification. Extracted so the decision logic can be unit-tested +// without constructing full Worker/Handshaker objects. +type connStatusInputs struct { + forceRelay bool // NB_FORCE_RELAY or JS/WASM + peerUsesRelay bool // remote peer advertises relay support AND local has relay + relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay) + remoteSupportsICE bool // remote peer sent ICE credentials + iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode) + iceStatusConnecting bool // statusICE is anything other than Disconnected + iceInProgress bool // a negotiation is currently in flight +} + + // ConnStatus describe the status of a peer's connection type ConnStatus int32 diff --git a/client/internal/peer/conn_status_eval_test.go b/client/internal/peer/conn_status_eval_test.go new file mode 100644 index 000000000..66393cafe --- /dev/null +++ b/client/internal/peer/conn_status_eval_test.go @@ -0,0 +1,201 @@ +package peer + +import ( + "testing" + + "github.com/netbirdio/netbird/client/internal/peer/guard" +) + +func TestEvalConnStatus_ForceRelay(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "force relay, peer uses relay, relay up", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "force relay, peer uses relay, relay down", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: false, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "force relay, peer does NOT use relay - disconnected forever", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: false, + relayConnected: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_ICEUnavailable(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "remote does not support ICE, peer uses relay, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer uses relay, relay down", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE worker not yet created, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: true, + iceWorkerCreated: false, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer does not use relay", + in: connStatusInputs{ + peerUsesRelay: false, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_FullyAvailable(t *testing.T) { + base := connStatusInputs{ + remoteSupportsICE: true, + iceWorkerCreated: true, + } + + tests := []struct { + name string + mutator func(*connStatusInputs) + want guard.ConnStatus + }{ + { + name: "ICE connected, relay connected, peer uses relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE connected, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE InProgress only, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.iceStatusConnecting = false + in.iceInProgress = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE down, relay up, peer uses relay -> partial", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusPartiallyConnected, + }, + { + name: "ICE down, peer does NOT use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = false + in.iceStatusConnecting = true + }, + // relayOK = false (peer uses relay but it's down), iceUp = true + // first switch arm fails (relayOK false), relayUsedAndUp = false (relay down), + // falls into default: Disconnected. + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE down, relay up but peer does not use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = true // not actually used since peer doesn't rely on it + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + in := base + tc.mutator(&in) + if got := evalConnStatus(in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in) + } + }) + } +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 7f500c410..b4ba9ad7b 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -10,7 +10,7 @@ const ( EnvKeyNBForceRelay = "NB_FORCE_RELAY" ) -func isForceRelayed() bool { +func IsForceRelayed() bool { if runtime.GOOS == "js" { return true } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index d93403730..2e5efbcc5 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -8,7 +8,19 @@ import ( log "github.com/sirupsen/logrus" ) -type isConnectedFunc func() bool +// ConnStatus represents the connection state as seen by the guard. +type ConnStatus int + +const ( + // ConnStatusDisconnected means neither ICE nor Relay is connected. + ConnStatusDisconnected ConnStatus = iota + // ConnStatusPartiallyConnected means Relay is connected but ICE is not. + ConnStatusPartiallyConnected + // ConnStatusConnected means all required connections are established. + ConnStatusConnected +) + +type connStatusFunc func() ConnStatus // Guard is responsible for the reconnection logic. // It will trigger to send an offer to the peer then has connection issues. @@ -20,14 +32,14 @@ type isConnectedFunc func() bool // - ICE candidate changes type Guard struct { log *log.Entry - isConnectedOnAllWay isConnectedFunc + isConnectedOnAllWay connStatusFunc timeout time.Duration srWatcher *SRWatcher relayedConnDisconnected chan struct{} iCEConnDisconnected chan struct{} } -func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { +func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { return &Guard{ log: log, isConnectedOnAllWay: isConnectedFn, @@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() { } } -// reconnectLoopWithRetry periodically check the connection status. -// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. +// +// Behavior depends on the connection state reported by isConnectedOnAllWay: +// - Connected: no action, the peer is fully reachable. +// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling +// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all. +// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches +// to one attempt per hour. This limits signaling traffic when relay already provides connectivity. +// +// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry +// counter and backoff ticker, giving ICE a fresh chance after network conditions change. func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { srReconnectedChan := g.srWatcher.NewListener() defer g.srWatcher.RemoveListener(srReconnectedChan) @@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { tickerChannel := ticker.C + iceState := &iceRetryState{log: g.log} + defer iceState.reset() + for { select { - case t := <-tickerChannel: - if t.IsZero() { - g.log.Infof("retry timed out, stop periodic offer sending") - // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop - tickerChannel = make(<-chan time.Time) - continue + case <-tickerChannel: + switch g.isConnectedOnAllWay() { + case ConnStatusConnected: + // all good, nothing to do + case ConnStatusDisconnected: + callback() + case ConnStatusPartiallyConnected: + if iceState.shouldRetry() { + callback() + } else { + iceState.enterHourlyMode() + ticker.Stop() + tickerChannel = iceState.hourlyC() + } } - if !g.isConnectedOnAllWay() { - callback() - } case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-srReconnectedChan: g.log.Debugf("has network changes, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-ctx.Done(): g.log.Debugf("context is done, stop reconnect loop") @@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } -func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 0.1, diff --git a/client/internal/peer/guard/ice_retry_state.go b/client/internal/peer/guard/ice_retry_state.go new file mode 100644 index 000000000..01dc1bf2d --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state.go @@ -0,0 +1,61 @@ +package guard + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // maxICERetries is the maximum number of ICE offer attempts when relay is connected + maxICERetries = 3 + // iceRetryInterval is the periodic retry interval after ICE retries are exhausted + iceRetryInterval = 1 * time.Hour +) + +// iceRetryState tracks the limited ICE retry attempts when relay is already connected. +// After maxICERetries attempts it switches to a periodic hourly retry. +type iceRetryState struct { + log *log.Entry + retries int + hourly *time.Ticker +} + +func (s *iceRetryState) reset() { + s.retries = 0 + if s.hourly != nil { + s.hourly.Stop() + s.hourly = nil + } +} + +// shouldRetry reports whether the caller should send another ICE offer on this tick. +// Returns false when the per-cycle retry budget is exhausted and the caller must switch +// to the hourly ticker via enterHourlyMode + hourlyC. +func (s *iceRetryState) shouldRetry() bool { + if s.hourly != nil { + s.log.Debugf("hourly ICE retry attempt") + return true + } + + s.retries++ + if s.retries <= maxICERetries { + s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries) + return true + } + + return false +} + +// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false. +func (s *iceRetryState) enterHourlyMode() { + s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries) + s.hourly = time.NewTicker(iceRetryInterval) +} + +func (s *iceRetryState) hourlyC() <-chan time.Time { + if s.hourly == nil { + return nil + } + return s.hourly.C +} diff --git a/client/internal/peer/guard/ice_retry_state_test.go b/client/internal/peer/guard/ice_retry_state_test.go new file mode 100644 index 000000000..6a5b5a76f --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state_test.go @@ -0,0 +1,103 @@ +package guard + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func newTestRetryState() *iceRetryState { + return &iceRetryState{log: log.NewEntry(log.StandardLogger())} +} + +func TestICERetryState_AllowsInitialBudget(t *testing.T) { + s := newTestRetryState() + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries) + } + } +} + +func TestICERetryState_ExhaustsAfterBudget(t *testing.T) { + s := newTestRetryState() + + for i := 0; i < maxICERetries; i++ { + _ = s.shouldRetry() + } + + if s.shouldRetry() { + t.Fatalf("shouldRetry returned true after budget exhausted, want false") + } +} + +func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) { + s := newTestRetryState() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode") + } +} + +func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + + s.enterHourlyMode() + defer s.reset() + + if s.hourlyC() == nil { + t.Fatalf("hourlyC returned nil after enterHourlyMode") + } +} + +func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) { + s := newTestRetryState() + s.enterHourlyMode() + defer s.reset() + + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false in hourly mode, want true") + } + + // Subsequent calls also return true — we keep retrying on each hourly tick. + if !s.shouldRetry() { + t.Fatalf("second shouldRetry returned false in hourly mode, want true") + } +} + +func TestICERetryState_ResetRestoresBudget(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + s.enterHourlyMode() + + s.reset() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel after reset") + } + if s.retries != 0 { + t.Fatalf("retries = %d after reset, want 0", s.retries) + } + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i) + } + } +} + +func TestICERetryState_ResetIsIdempotent(t *testing.T) { + s := newTestRetryState() + s.reset() + s.reset() // second call must not panic or re-stop a nil ticker + + if s.hourlyC() != nil { + t.Fatalf("hourlyC non-nil after double reset") + } +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 6f4f5ad4f..0befd7438 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove return srw } -func (w *SRWatcher) Start() { +func (w *SRWatcher) Start(disableICEMonitor bool) { w.mu.Lock() defer w.mu.Unlock() @@ -50,8 +50,10 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + if !disableICEMonitor { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) + go iceMonitor.Start(ctx, w.onICEChanged) + } w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 9b50cecd1..741dfce60 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" log "github.com/sirupsen/logrus" @@ -43,6 +44,10 @@ type OfferAnswer struct { SessionID *ICESessionID } +func (o *OfferAnswer) hasICECredentials() bool { + return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != "" +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -59,6 +64,10 @@ type Handshaker struct { relayListener *AsyncOfferListener iceListener func(remoteOfferAnswer *OfferAnswer) + // remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers. + // When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses. + remoteICESupported atomic.Bool + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection @@ -66,7 +75,7 @@ type Handshaker struct { } func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { - return &Handshaker{ + h := &Handshaker{ log: log, config: config, signaler: signaler, @@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } + // assume remote supports ICE until we learn otherwise from received offers + h.remoteICESupported.Store(ice != nil) + return h +} + +func (h *Handshaker) RemoteICESupported() bool { + return h.remoteICESupported.Load() } func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { @@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } @@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } case remoteOfferAnswer := <-h.remoteAnswerCh: - h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): @@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error { } func (h *Handshaker) buildOfferAnswer() OfferAnswer { - uFrag, pwd := h.ice.GetLocalUserCredentials() - sid := h.ice.SessionID() answer := OfferAnswer{ - IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassAddr: h.config.RosenpassConfig.Addr, - SessionID: &sid, + } + + if h.ice != nil && h.RemoteICESupported() { + uFrag, pwd := h.ice.GetLocalUserCredentials() + sid := h.ice.SessionID() + answer.IceCredentials = IceCredentials{uFrag, pwd} + answer.SessionID = &sid } if addr, err := h.relay.RelayInstanceAddress(); err == nil { @@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer { return answer } + +func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) { + hasICE := offer.hasICECredentials() + prev := h.remoteICESupported.Swap(hasICE) + if prev != hasICE { + if hasICE { + h.log.Infof("remote peer started sending ICE credentials") + } else { + h.log.Infof("remote peer stopped sending ICE credentials") + if h.ice != nil { + h.ice.Close() + } + } + } +} diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..f6eb87cca 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool { // SignalOfferAnswer signals either an offer or an answer to remote peer func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { - sessionIDBytes, err := offerAnswer.SessionID.Bytes() - if err != nil { - log.Warnf("failed to get session ID bytes: %v", err) + var sessionIDBytes []byte + if offerAnswer.SessionID != nil { + var err error + sessionIDBytes, err = offerAnswer.SessionID.Bytes() + if err != nil { + log.Warnf("failed to get session ID bytes: %v", err) + } } msg, err := signal.MarshalCredential( s.wgPrivateKey,