[client] Recreate agent when receive new session id (#4564)

When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time.
In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started.
This commit is contained in:
Zoltan Papp
2025-10-08 17:14:24 +02:00
committed by GitHub
parent 768332820e
commit 9021bb512b
9 changed files with 107 additions and 59 deletions

View File

@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteOffer(OfferAnswer{ conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
onNewOffeChan := make(chan struct{}) onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{} onNewOfferChan <- struct{}{}
}) })
conn.OnRemoteAnswer(OfferAnswer{ conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOffeChan: case <-onNewOfferChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -0,0 +1,20 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
func GetICEMonitorPeriod() time.Duration {
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return defaultCandidatesMonitorPeriod
}

View File

@@ -16,8 +16,8 @@ import (
) )
const ( const (
candidatesMonitorPeriod = 5 * time.Minute defaultCandidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second candidateGatheringTimeout = 5 * time.Second
) )
type ICEMonitor struct { type ICEMonitor struct {
@@ -25,16 +25,19 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidatesAddress []string currentCandidatesAddress []string
candidatesMu sync.Mutex candidatesMu sync.Mutex
} }
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
cm := &ICEMonitor{ cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1), ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
iceConfig: config, iceConfig: config,
tickerPeriod: period,
} }
return cm return cm
} }
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return return
} }
ticker := time.NewTicker(candidatesMonitorPeriod) // Initial check to populate the candidates for later comparison
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
defer ticker.Stop() defer ticker.Stop()
for { for {

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged) go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,13 +44,19 @@ type OfferAnswer struct {
} }
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []*OfferListener // relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
l := NewOfferListener(offer) h.relayListener = NewAsyncOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l) }
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
if err := h.sendAnswer(); err != nil { if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err) h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue continue
} }
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { if h.relayListener != nil {
listener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String() return oa.SessionID.String()
} }
type OfferListener struct { type AsyncOfferListener struct {
fn callbackFunc fn callbackFunc
running bool running bool
latest *OfferAnswer latest *OfferAnswer
mu sync.Mutex mu sync.Mutex
} }
func NewOfferListener(fn callbackFunc) *OfferListener { func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
return &OfferListener{ return &AsyncOfferListener{
fn: fn, fn: fn,
} }
} }
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock() o.mu.Lock()
defer o.mu.Unlock() defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{} runChan <- struct{}{}
} }
hl := NewOfferListener(longRunningFn) hl := NewAsyncOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)

View File

@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agentConnecting { if w.agent != nil || w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID // backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil { if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return return
} }
if w.remoteSessionID == *remoteOfferAnswer.SessionID { if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
} }
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
w.log.Debugf("recreate ICE agent") if remoteOfferAnswer.SessionID != nil {
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
dialerCtx, dialerCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return return
} }
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
w.muxAgent.Unlock() if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
go w.connect(dialerCtx, agent, remoteOfferAnswer) go w.connect(dialerCtx, agent, remoteOfferAnswer)
} }
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock() w.muxAgent.Lock()
w.agentConnecting = false w.agentConnecting = false
w.lastSuccess = time.Now() w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock() w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange // todo: the potential problem is a race between the onConnectionStateChange
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
} }
w.muxAgent.Lock() w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = ""
} }
w.muxAgent.Unlock() w.muxAgent.Unlock()
} }
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
} }
w.closeAgent(agent, dialerCancel)
default: default:
return return
} }