From 3739c237c774f6d426e59ab49c5590c0ddeec085 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 17 Jan 2026 17:59:24 -0800 Subject: [PATCH] Handle rebind in the polling function --- bind/shared_bind.go | 113 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 18 deletions(-) diff --git a/bind/shared_bind.go b/bind/shared_bind.go index 7f40d04..ae930bd 100644 --- a/bind/shared_bind.go +++ b/bind/shared_bind.go @@ -10,6 +10,7 @@ import ( "runtime" "sync" "sync/atomic" + "time" "github.com/fosrl/newt/logger" "golang.org/x/net/ipv4" @@ -144,6 +145,10 @@ type SharedBind struct { // Callback for magic test responses (used for holepunch testing) magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)] + + // Rebinding state - used to keep receive goroutines alive during socket transition + rebinding bool // true when socket is being replaced + rebindingCond *sync.Cond // signaled when rebind completes } // MagicResponseCallback is the function signature for magic packet response callbacks @@ -163,6 +168,9 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { closeChan: make(chan struct{}), } + // Initialize the rebinding condition variable + bind.rebindingCond = sync.NewCond(&bind.mu) + // Initialize reference count to 1 (the creator holds the first reference) bind.refCount.Store(1) @@ -324,6 +332,8 @@ func (b *SharedBind) GetPort() uint16 { // to the same port before calling Rebind. // // Returns the port that was being used, so the caller can attempt to rebind to it. +// Sets the rebinding flag so receive goroutines will wait for the new socket +// instead of exiting. func (b *SharedBind) CloseSocket() (uint16, error) { b.mu.Lock() defer b.mu.Unlock() @@ -334,9 +344,13 @@ func (b *SharedBind) CloseSocket() (uint16, error) { port := b.port + // Set rebinding flag BEFORE closing the socket so receive goroutines + // know to wait instead of exit + b.rebinding = true + // Close the old connection to release the port if b.udpConn != nil { - logger.Debug("Closing UDP connection to release port %d", port) + logger.Debug("Closing UDP connection to release port %d (rebinding)", port) b.udpConn.Close() b.udpConn = nil } @@ -398,9 +412,11 @@ func (b *SharedBind) Rebind(newConn *net.UDPConn) error { logger.Info("Rebound UDP socket to port %d", b.port) } - // Note: recvFuncs don't need to be recreated because they reference b.udpConn - // and b.ipv4PC through the SharedBind struct, which we just updated. - // The receive functions will automatically use the new connection on their next read. + // Clear the rebinding flag and wake up any waiting receive goroutines + b.rebinding = false + b.rebindingCond.Broadcast() + + logger.Debug("Rebind complete, signaled waiting receive goroutines") return nil } @@ -487,24 +503,77 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { // makeReceiveSocket creates a receive function for physical UDP socket packets func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - if b.closed.Load() { - return 0, net.ErrClosed - } + for { + if b.closed.Load() { + return 0, net.ErrClosed + } - b.mu.RLock() - conn := b.udpConn - pc := b.ipv4PC - b.mu.RUnlock() + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() - if conn == nil { - return 0, net.ErrClosed - } + if conn == nil { + // Socket is nil - check if we're rebinding or truly closed + if b.closed.Load() { + return 0, net.ErrClosed + } - // Use batch reading on Linux for performance - if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - return b.receiveIPv4Batch(pc, bufs, sizes, eps) + // Wait for rebind to complete + b.mu.Lock() + for b.rebinding && !b.closed.Load() { + logger.Debug("Receive goroutine waiting for socket rebind to complete") + b.rebindingCond.Wait() + } + b.mu.Unlock() + + // Check again after waking up + if b.closed.Load() { + return 0, net.ErrClosed + } + + // Loop back to retry with new socket + continue + } + + // Use batch reading on Linux for performance + var n int + var err error + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps) + } else { + n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps) + } + + if err != nil { + // Check if this error is due to rebinding + b.mu.RLock() + rebinding := b.rebinding + b.mu.RUnlock() + + if rebinding { + logger.Debug("Receive got error during rebind, waiting for new socket: %v", err) + // Wait for rebind to complete and retry + b.mu.Lock() + for b.rebinding && !b.closed.Load() { + b.rebindingCond.Wait() + } + b.mu.Unlock() + + if b.closed.Load() { + return 0, net.ErrClosed + } + + // Retry with new socket + continue + } + + // Not rebinding, return the error + return 0, err + } + + return n, nil } - return b.receiveIPv4Simple(conn, bufs, sizes, eps) } } @@ -587,9 +656,17 @@ func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes // receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + // Set a read deadline so we can periodically check for rebind state + // This prevents blocking forever on a socket that's about to be closed + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) for { n, addr, err := conn.ReadFromUDP(bufs[0]) if err != nil { + // Check if this is a timeout - if so, just return the error + // so the caller can check rebind state and retry + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return 0, err + } return 0, err }