mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Handle rebind in the polling function
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/net/ipv4"
|
||||
@@ -144,6 +145,10 @@ type SharedBind struct {
|
||||
|
||||
// Callback for magic test responses (used for holepunch testing)
|
||||
magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)]
|
||||
|
||||
// Rebinding state - used to keep receive goroutines alive during socket transition
|
||||
rebinding bool // true when socket is being replaced
|
||||
rebindingCond *sync.Cond // signaled when rebind completes
|
||||
}
|
||||
|
||||
// MagicResponseCallback is the function signature for magic packet response callbacks
|
||||
@@ -163,6 +168,9 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize the rebinding condition variable
|
||||
bind.rebindingCond = sync.NewCond(&bind.mu)
|
||||
|
||||
// Initialize reference count to 1 (the creator holds the first reference)
|
||||
bind.refCount.Store(1)
|
||||
|
||||
@@ -324,6 +332,8 @@ func (b *SharedBind) GetPort() uint16 {
|
||||
// to the same port before calling Rebind.
|
||||
//
|
||||
// Returns the port that was being used, so the caller can attempt to rebind to it.
|
||||
// Sets the rebinding flag so receive goroutines will wait for the new socket
|
||||
// instead of exiting.
|
||||
func (b *SharedBind) CloseSocket() (uint16, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
@@ -334,9 +344,13 @@ func (b *SharedBind) CloseSocket() (uint16, error) {
|
||||
|
||||
port := b.port
|
||||
|
||||
// Set rebinding flag BEFORE closing the socket so receive goroutines
|
||||
// know to wait instead of exit
|
||||
b.rebinding = true
|
||||
|
||||
// Close the old connection to release the port
|
||||
if b.udpConn != nil {
|
||||
logger.Debug("Closing UDP connection to release port %d", port)
|
||||
logger.Debug("Closing UDP connection to release port %d (rebinding)", port)
|
||||
b.udpConn.Close()
|
||||
b.udpConn = nil
|
||||
}
|
||||
@@ -398,9 +412,11 @@ func (b *SharedBind) Rebind(newConn *net.UDPConn) error {
|
||||
logger.Info("Rebound UDP socket to port %d", b.port)
|
||||
}
|
||||
|
||||
// Note: recvFuncs don't need to be recreated because they reference b.udpConn
|
||||
// and b.ipv4PC through the SharedBind struct, which we just updated.
|
||||
// The receive functions will automatically use the new connection on their next read.
|
||||
// Clear the rebinding flag and wake up any waiting receive goroutines
|
||||
b.rebinding = false
|
||||
b.rebindingCond.Broadcast()
|
||||
|
||||
logger.Debug("Rebind complete, signaled waiting receive goroutines")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -487,24 +503,77 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
// makeReceiveSocket creates a receive function for physical UDP socket packets
|
||||
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
for {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
pc := b.ipv4PC
|
||||
b.mu.RUnlock()
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
pc := b.ipv4PC
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if conn == nil {
|
||||
// Socket is nil - check if we're rebinding or truly closed
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Use batch reading on Linux for performance
|
||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||
// Wait for rebind to complete
|
||||
b.mu.Lock()
|
||||
for b.rebinding && !b.closed.Load() {
|
||||
logger.Debug("Receive goroutine waiting for socket rebind to complete")
|
||||
b.rebindingCond.Wait()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Check again after waking up
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Loop back to retry with new socket
|
||||
continue
|
||||
}
|
||||
|
||||
// Use batch reading on Linux for performance
|
||||
var n int
|
||||
var err error
|
||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||
} else {
|
||||
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Check if this error is due to rebinding
|
||||
b.mu.RLock()
|
||||
rebinding := b.rebinding
|
||||
b.mu.RUnlock()
|
||||
|
||||
if rebinding {
|
||||
logger.Debug("Receive got error during rebind, waiting for new socket: %v", err)
|
||||
// Wait for rebind to complete and retry
|
||||
b.mu.Lock()
|
||||
for b.rebinding && !b.closed.Load() {
|
||||
b.rebindingCond.Wait()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Retry with new socket
|
||||
continue
|
||||
}
|
||||
|
||||
// Not rebinding, return the error
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,9 +656,17 @@ func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes
|
||||
|
||||
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
// Set a read deadline so we can periodically check for rebind state
|
||||
// This prevents blocking forever on a socket that's about to be closed
|
||||
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
for {
|
||||
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||
if err != nil {
|
||||
// Check if this is a timeout - if so, just return the error
|
||||
// so the caller can check rebind state and retry
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return 0, err
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user