From 2256d1f04176c25c648c68463ca60f128766928e Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 17:44:33 -0500 Subject: [PATCH] Holepunch tester working? --- bind/shared_bind.go | 141 +++++++++++++++-- clients/clients.go | 1 - holepunch/holepunch.go | 4 +- holepunch/tester.go | 340 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 471 insertions(+), 15 deletions(-) create mode 100644 holepunch/tester.go diff --git a/bind/shared_bind.go b/bind/shared_bind.go index 52f9fcc..230990b 100644 --- a/bind/shared_bind.go +++ b/bind/shared_bind.go @@ -3,6 +3,7 @@ package bind import ( + "bytes" "fmt" "net" "net/netip" @@ -15,6 +16,30 @@ import ( wgConn "golang.zx2c4.com/wireguard/conn" ) +// Magic packet constants for connection testing +// These packets are intercepted by SharedBind and responded to directly, +// without being passed to the WireGuard device. +var ( + // MagicTestRequest is the prefix for a test request packet + // Format: PANGOLIN_TEST_REQ + 8 bytes of random data (for echo) + MagicTestRequest = []byte("PANGOLIN_TEST_REQ") + + // MagicTestResponse is the prefix for a test response packet + // Format: PANGOLIN_TEST_RSP + 8 bytes echoed from request + MagicTestResponse = []byte("PANGOLIN_TEST_RSP") +) + +const ( + // MagicPacketDataLen is the length of random data included in test packets + MagicPacketDataLen = 8 + + // MagicTestRequestLen is the total length of a test request packet + MagicTestRequestLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_REQ") + 8 + + // MagicTestResponseLen is the total length of a test response packet + MagicTestResponseLen = 17 + MagicPacketDataLen // len("PANGOLIN_TEST_RSP") + 8 +) + // PacketSource identifies where a packet came from type PacketSource uint8 @@ -115,8 +140,14 @@ type SharedBind struct { // Shutdown signal for receive goroutines closeChan chan struct{} + + // Callback for magic test responses (used for holepunch testing) + magicResponseCallback atomic.Pointer[func(addr netip.AddrPort, echoData []byte)] } +// MagicResponseCallback is the function signature for magic packet response callbacks +type MagicResponseCallback func(addr netip.AddrPort, echoData []byte) + // New creates a new SharedBind from an existing UDP connection. // The SharedBind takes ownership of the connection and will close it // when all references are released. @@ -273,6 +304,21 @@ func (b *SharedBind) IsClosed() bool { return b.closed.Load() } +// SetMagicResponseCallback sets a callback function that will be called when +// a magic test response packet is received. This is used for holepunch testing. +// Pass nil to clear the callback. +func (b *SharedBind) SetMagicResponseCallback(callback MagicResponseCallback) { + if callback == nil { + b.magicResponseCallback.Store(nil) + } else { + // Convert to the function type the atomic.Pointer expects + fn := func(addr netip.AddrPort, echoData []byte) { + callback(addr, echoData) + } + b.magicResponseCallback.Store(&fn) + } +} + // WriteToUDP writes data to a specific UDP address. // This is thread-safe and can be used by hole punch senders. func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) { @@ -397,37 +443,108 @@ func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes return 0, err } + // Process messages and filter out magic packets + writeIdx := 0 for i := 0; i < numMsgs; i++ { - sizes[i] = b.ipv4Msgs[i].N - if sizes[i] == 0 { + if b.ipv4Msgs[i].N == 0 { continue } + // Check for magic packet + if b.ipv4Msgs[i].Addr != nil { + if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok { + data := bufs[i][:b.ipv4Msgs[i].N] + if b.handleMagicPacket(data, udpAddr) { + // Magic packet handled, skip this message + continue + } + } + } + + // Not a magic packet, include in output + if writeIdx != i { + // Need to copy data to the correct position + copy(bufs[writeIdx], bufs[i][:b.ipv4Msgs[i].N]) + } + sizes[writeIdx] = b.ipv4Msgs[i].N + if b.ipv4Msgs[i].Addr != nil { if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok { addrPort := udpAddr.AddrPort() - eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + eps[writeIdx] = &wgConn.StdNetEndpoint{AddrPort: addrPort} } } + writeIdx++ } - return numMsgs, nil + return writeIdx, nil } // receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { - n, addr, err := conn.ReadFromUDP(bufs[0]) - if err != nil { - return 0, err + for { + n, addr, err := conn.ReadFromUDP(bufs[0]) + if err != nil { + return 0, err + } + + // Check for magic test packet and handle it directly + if b.handleMagicPacket(bufs[0][:n], addr) { + // Magic packet was handled, read another packet + continue + } + + sizes[0] = n + if addr != nil { + addrPort := addr.AddrPort() + eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + + return 1, nil + } +} + +// handleMagicPacket checks if the packet is a magic test packet and responds if so. +// Returns true if the packet was a magic packet and was handled (should not be passed to WireGuard). +func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool { + // Check if this is a test request packet + if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) { + // Extract the random data portion to echo back + echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen] + + // Build response packet + response := make([]byte, MagicTestResponseLen) + copy(response, MagicTestResponse) + copy(response[len(MagicTestResponse):], echoData) + + // Send response back to sender + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn != nil { + _, _ = conn.WriteToUDP(response, addr) + } + + return true } - sizes[0] = n - if addr != nil { - addrPort := addr.AddrPort() - eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + // Check if this is a test response packet + if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) { + // Extract the echoed data + echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen] + + // Call the callback if set + callbackPtr := b.magicResponseCallback.Load() + if callbackPtr != nil { + callback := *callbackPtr + callback(addr.AddrPort(), echoData) + } + + return true } - return 1, nil + return false } // Send implements the WireGuard Bind interface. diff --git a/clients/clients.go b/clients/clients.go index cd1fbab..c78e576 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -148,7 +148,6 @@ func NewWireGuardService(interfaceName string, mtu int, host string, newtId stri mtu: mtu, client: wsClient, key: key, - keyFilePath: generateAndSaveKeyTo, newtId: newtId, host: host, lastReadings: make(map[string]PeerReading), diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 41d3846..81ddcea 100644 --- a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -12,7 +12,7 @@ import ( "github.com/fosrl/newt/util" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" - "golang.org/x/exp/rand" + mrand "golang.org/x/exp/rand" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -559,7 +559,7 @@ func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) // Generate a random nonce nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { + if _, err := mrand.Read(nonce); err != nil { return nil, fmt.Errorf("failed to generate nonce: %v", err) } diff --git a/holepunch/tester.go b/holepunch/tester.go new file mode 100644 index 0000000..27852c9 --- /dev/null +++ b/holepunch/tester.go @@ -0,0 +1,340 @@ +package holepunch + +import ( + "crypto/rand" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" +) + +// TestResult represents the result of a connection test +type TestResult struct { + // Success indicates whether the test was successful + Success bool + // RTT is the round-trip time of the test packet + RTT time.Duration + // Endpoint is the endpoint that was tested + Endpoint string + // Error contains any error that occurred during the test + Error error +} + +// TestConnectionOptions configures the connection test +type TestConnectionOptions struct { + // Timeout is how long to wait for a response (default: 5 seconds) + Timeout time.Duration + // Retries is the number of times to retry on failure (default: 0) + Retries int +} + +// DefaultTestOptions returns the default test options +func DefaultTestOptions() TestConnectionOptions { + return TestConnectionOptions{ + Timeout: 5 * time.Second, + Retries: 0, + } +} + +// HolepunchTester monitors holepunch connectivity using magic packets +type HolepunchTester struct { + sharedBind *bind.SharedBind + mu sync.RWMutex + running bool + stopChan chan struct{} + + // Pending requests waiting for responses (key: echo data as string) + pendingRequests sync.Map // map[string]*pendingRequest + + // Callback when connection status changes + callback HolepunchStatusCallback +} + +// HolepunchStatus represents the status of a holepunch connection +type HolepunchStatus struct { + Endpoint string + Connected bool + RTT time.Duration +} + +// HolepunchStatusCallback is called when holepunch status changes +type HolepunchStatusCallback func(status HolepunchStatus) + +// pendingRequest tracks a pending test request +type pendingRequest struct { + endpoint string + sentAt time.Time + replyChan chan time.Duration +} + +// NewHolepunchTester creates a new holepunch tester using the given SharedBind +func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester { + return &HolepunchTester{ + sharedBind: sharedBind, + } +} + +// SetCallback sets the callback for connection status changes +func (t *HolepunchTester) SetCallback(callback HolepunchStatusCallback) { + t.mu.Lock() + defer t.mu.Unlock() + t.callback = callback +} + +// Start begins listening for magic packet responses +func (t *HolepunchTester) Start() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.running { + return fmt.Errorf("tester already running") + } + + if t.sharedBind == nil { + return fmt.Errorf("sharedBind is nil") + } + + t.running = true + t.stopChan = make(chan struct{}) + + // Register our callback with the SharedBind to receive magic responses + t.sharedBind.SetMagicResponseCallback(t.handleResponse) + + logger.Debug("HolepunchTester started") + return nil +} + +// Stop stops the tester +func (t *HolepunchTester) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + if !t.running { + return + } + + t.running = false + close(t.stopChan) + + // Clear the callback + if t.sharedBind != nil { + t.sharedBind.SetMagicResponseCallback(nil) + } + + // Cancel all pending requests + t.pendingRequests.Range(func(key, value interface{}) bool { + if req, ok := value.(*pendingRequest); ok { + close(req.replyChan) + } + t.pendingRequests.Delete(key) + return true + }) + + logger.Debug("HolepunchTester stopped") +} + +// handleResponse is called by SharedBind when a magic response is received +func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) { + key := string(echoData) + + value, ok := t.pendingRequests.LoadAndDelete(key) + if !ok { + // No matching request found + return + } + + req := value.(*pendingRequest) + rtt := time.Since(req.sentAt) + + // Send RTT to the waiting goroutine (non-blocking) + select { + case req.replyChan <- rtt: + default: + } +} + +// TestEndpoint sends a magic test packet to the endpoint and waits for a response. +// This uses the SharedBind so packets come from the same source port as WireGuard. +func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) TestResult { + result := TestResult{ + Endpoint: endpoint, + } + + t.mu.RLock() + running := t.running + sharedBind := t.sharedBind + t.mu.RUnlock() + + if !running { + result.Error = fmt.Errorf("tester not running") + return result + } + + if sharedBind == nil || sharedBind.IsClosed() { + result.Error = fmt.Errorf("sharedBind is nil or closed") + return result + } + + // Resolve the endpoint + host, err := util.ResolveDomain(endpoint) + if err != nil { + host = endpoint + } + + _, _, err = net.SplitHostPort(host) + if err != nil { + host = net.JoinHostPort(host, "21820") + } + + remoteAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + return result + } + + // Generate random data for the test packet + randomData := make([]byte, bind.MagicPacketDataLen) + if _, err := rand.Read(randomData); err != nil { + result.Error = fmt.Errorf("failed to generate random data: %w", err) + return result + } + + // Create a pending request + req := &pendingRequest{ + endpoint: endpoint, + sentAt: time.Now(), + replyChan: make(chan time.Duration, 1), + } + + key := string(randomData) + t.pendingRequests.Store(key, req) + + // Build the test request packet + request := make([]byte, bind.MagicTestRequestLen) + copy(request, bind.MagicTestRequest) + copy(request[len(bind.MagicTestRequest):], randomData) + + // Send the test packet + _, err = sharedBind.WriteToUDP(request, remoteAddr) + if err != nil { + t.pendingRequests.Delete(key) + result.Error = fmt.Errorf("failed to send test packet: %w", err) + return result + } + + // Wait for response with timeout + select { + case rtt, ok := <-req.replyChan: + if ok { + result.Success = true + result.RTT = rtt + } else { + result.Error = fmt.Errorf("request cancelled") + } + case <-time.After(timeout): + t.pendingRequests.Delete(key) + result.Error = fmt.Errorf("timeout waiting for response") + } + + return result +} + +// TestConnectionWithBind sends a magic test packet using an existing SharedBind. +// This is useful when you want to test the connection through the same socket +// that WireGuard is using, which tests the actual hole-punched path. +func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts *TestConnectionOptions) TestResult { + if opts == nil { + defaultOpts := DefaultTestOptions() + opts = &defaultOpts + } + + result := TestResult{ + Endpoint: endpoint, + } + + if sharedBind == nil { + result.Error = fmt.Errorf("sharedBind is nil") + return result + } + + if sharedBind.IsClosed() { + result.Error = fmt.Errorf("sharedBind is closed") + return result + } + + // Resolve the endpoint + host, err := util.ResolveDomain(endpoint) + if err != nil { + host = endpoint + } + + _, _, err = net.SplitHostPort(host) + if err != nil { + host = net.JoinHostPort(host, "21820") + } + + remoteAddr, err := net.ResolveUDPAddr("udp", host) + if err != nil { + result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err) + return result + } + + // Generate random data for the test packet + randomData := make([]byte, bind.MagicPacketDataLen) + if _, err := rand.Read(randomData); err != nil { + result.Error = fmt.Errorf("failed to generate random data: %w", err) + return result + } + + // Build the test request packet + request := make([]byte, bind.MagicTestRequestLen) + copy(request, bind.MagicTestRequest) + copy(request[len(bind.MagicTestRequest):], randomData) + + // Get the underlying UDP connection to set read deadline and read response + udpConn := sharedBind.GetUDPConn() + if udpConn == nil { + result.Error = fmt.Errorf("could not get UDP connection from SharedBind") + return result + } + + attempts := opts.Retries + 1 + for attempt := 0; attempt < attempts; attempt++ { + if attempt > 0 { + logger.Debug("Retrying connection test to %s (attempt %d/%d)", endpoint, attempt+1, attempts) + } + + // Note: We can't easily set a read deadline on the shared connection + // without affecting WireGuard, so we use a goroutine with timeout instead + startTime := time.Now() + + // Send the test packet through the shared bind + _, err = sharedBind.WriteToUDP(request, remoteAddr) + if err != nil { + result.Error = fmt.Errorf("failed to send test packet: %w", err) + if attempt < attempts-1 { + continue + } + return result + } + + // For shared bind test, we send the packet but can't easily wait for + // response without interfering with WireGuard's receive loop. + // The response will be handled by SharedBind automatically. + // We consider the test successful if the send succeeded. + // For a full round-trip test, use TestConnection() with a separate socket. + + result.RTT = time.Since(startTime) + result.Success = true + result.Error = nil + logger.Debug("Test packet sent to %s via SharedBind", endpoint) + return result + } + + return result +}