mirror of
https://github.com/fosrl/newt.git
synced 2026-03-08 03:36:40 +00:00
Handle rebind in the polling function
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -144,6 +145,10 @@ type SharedBind struct {
|
|||||||
|
|
||||||
// Callback for magic test responses (used for holepunch testing)
|
// Callback for magic test responses (used for holepunch testing)
|
||||||
magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)]
|
magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)]
|
||||||
|
|
||||||
|
// Rebinding state - used to keep receive goroutines alive during socket transition
|
||||||
|
rebinding bool // true when socket is being replaced
|
||||||
|
rebindingCond *sync.Cond // signaled when rebind completes
|
||||||
}
|
}
|
||||||
|
|
||||||
// MagicResponseCallback is the function signature for magic packet response callbacks
|
// MagicResponseCallback is the function signature for magic packet response callbacks
|
||||||
@@ -163,6 +168,9 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
|||||||
closeChan: make(chan struct{}),
|
closeChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the rebinding condition variable
|
||||||
|
bind.rebindingCond = sync.NewCond(&bind.mu)
|
||||||
|
|
||||||
// Initialize reference count to 1 (the creator holds the first reference)
|
// Initialize reference count to 1 (the creator holds the first reference)
|
||||||
bind.refCount.Store(1)
|
bind.refCount.Store(1)
|
||||||
|
|
||||||
@@ -324,6 +332,8 @@ func (b *SharedBind) GetPort() uint16 {
|
|||||||
// to the same port before calling Rebind.
|
// to the same port before calling Rebind.
|
||||||
//
|
//
|
||||||
// Returns the port that was being used, so the caller can attempt to rebind to it.
|
// Returns the port that was being used, so the caller can attempt to rebind to it.
|
||||||
|
// Sets the rebinding flag so receive goroutines will wait for the new socket
|
||||||
|
// instead of exiting.
|
||||||
func (b *SharedBind) CloseSocket() (uint16, error) {
|
func (b *SharedBind) CloseSocket() (uint16, error) {
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
@@ -334,9 +344,13 @@ func (b *SharedBind) CloseSocket() (uint16, error) {
|
|||||||
|
|
||||||
port := b.port
|
port := b.port
|
||||||
|
|
||||||
|
// Set rebinding flag BEFORE closing the socket so receive goroutines
|
||||||
|
// know to wait instead of exit
|
||||||
|
b.rebinding = true
|
||||||
|
|
||||||
// Close the old connection to release the port
|
// Close the old connection to release the port
|
||||||
if b.udpConn != nil {
|
if b.udpConn != nil {
|
||||||
logger.Debug("Closing UDP connection to release port %d", port)
|
logger.Debug("Closing UDP connection to release port %d (rebinding)", port)
|
||||||
b.udpConn.Close()
|
b.udpConn.Close()
|
||||||
b.udpConn = nil
|
b.udpConn = nil
|
||||||
}
|
}
|
||||||
@@ -398,9 +412,11 @@ func (b *SharedBind) Rebind(newConn *net.UDPConn) error {
|
|||||||
logger.Info("Rebound UDP socket to port %d", b.port)
|
logger.Info("Rebound UDP socket to port %d", b.port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: recvFuncs don't need to be recreated because they reference b.udpConn
|
// Clear the rebinding flag and wake up any waiting receive goroutines
|
||||||
// and b.ipv4PC through the SharedBind struct, which we just updated.
|
b.rebinding = false
|
||||||
// The receive functions will automatically use the new connection on their next read.
|
b.rebindingCond.Broadcast()
|
||||||
|
|
||||||
|
logger.Debug("Rebind complete, signaled waiting receive goroutines")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -487,24 +503,77 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
|||||||
// makeReceiveSocket creates a receive function for physical UDP socket packets
|
// makeReceiveSocket creates a receive function for physical UDP socket packets
|
||||||
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
|
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
if b.closed.Load() {
|
for {
|
||||||
return 0, net.ErrClosed
|
if b.closed.Load() {
|
||||||
}
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
b.mu.RLock()
|
b.mu.RLock()
|
||||||
conn := b.udpConn
|
conn := b.udpConn
|
||||||
pc := b.ipv4PC
|
pc := b.ipv4PC
|
||||||
b.mu.RUnlock()
|
b.mu.RUnlock()
|
||||||
|
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return 0, net.ErrClosed
|
// Socket is nil - check if we're rebinding or truly closed
|
||||||
}
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
// Use batch reading on Linux for performance
|
// Wait for rebind to complete
|
||||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
b.mu.Lock()
|
||||||
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
for b.rebinding && !b.closed.Load() {
|
||||||
|
logger.Debug("Receive goroutine waiting for socket rebind to complete")
|
||||||
|
b.rebindingCond.Wait()
|
||||||
|
}
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
// Check again after waking up
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop back to retry with new socket
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use batch reading on Linux for performance
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||||
|
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||||
|
} else {
|
||||||
|
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Check if this error is due to rebinding
|
||||||
|
b.mu.RLock()
|
||||||
|
rebinding := b.rebinding
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if rebinding {
|
||||||
|
logger.Debug("Receive got error during rebind, waiting for new socket: %v", err)
|
||||||
|
// Wait for rebind to complete and retry
|
||||||
|
b.mu.Lock()
|
||||||
|
for b.rebinding && !b.closed.Load() {
|
||||||
|
b.rebindingCond.Wait()
|
||||||
|
}
|
||||||
|
b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.closed.Load() {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry with new socket
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not rebinding, return the error
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,9 +656,17 @@ func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes
|
|||||||
|
|
||||||
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||||
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
|
// Set a read deadline so we can periodically check for rebind state
|
||||||
|
// This prevents blocking forever on a socket that's about to be closed
|
||||||
|
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||||
for {
|
for {
|
||||||
n, addr, err := conn.ReadFromUDP(bufs[0])
|
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Check if this is a timeout - if so, just return the error
|
||||||
|
// so the caller can check rebind state and retry
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user