mirror of
https://github.com/fosrl/newt.git
synced 2026-03-06 10:46:40 +00:00
Holepunch tester working?
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
mrand "golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -559,7 +559,7 @@ func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error)
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
if _, err := mrand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
|
||||
340
holepunch/tester.go
Normal file
340
holepunch/tester.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
// TestResult represents the result of a connection test
|
||||
type TestResult struct {
|
||||
// Success indicates whether the test was successful
|
||||
Success bool
|
||||
// RTT is the round-trip time of the test packet
|
||||
RTT time.Duration
|
||||
// Endpoint is the endpoint that was tested
|
||||
Endpoint string
|
||||
// Error contains any error that occurred during the test
|
||||
Error error
|
||||
}
|
||||
|
||||
// TestConnectionOptions configures the connection test
|
||||
type TestConnectionOptions struct {
|
||||
// Timeout is how long to wait for a response (default: 5 seconds)
|
||||
Timeout time.Duration
|
||||
// Retries is the number of times to retry on failure (default: 0)
|
||||
Retries int
|
||||
}
|
||||
|
||||
// DefaultTestOptions returns the default test options
|
||||
func DefaultTestOptions() TestConnectionOptions {
|
||||
return TestConnectionOptions{
|
||||
Timeout: 5 * time.Second,
|
||||
Retries: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// HolepunchTester monitors holepunch connectivity using magic packets
|
||||
type HolepunchTester struct {
|
||||
sharedBind *bind.SharedBind
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
|
||||
// Pending requests waiting for responses (key: echo data as string)
|
||||
pendingRequests sync.Map // map[string]*pendingRequest
|
||||
|
||||
// Callback when connection status changes
|
||||
callback HolepunchStatusCallback
|
||||
}
|
||||
|
||||
// HolepunchStatus represents the status of a holepunch connection
|
||||
type HolepunchStatus struct {
|
||||
Endpoint string
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}
|
||||
|
||||
// HolepunchStatusCallback is called when holepunch status changes
|
||||
type HolepunchStatusCallback func(status HolepunchStatus)
|
||||
|
||||
// pendingRequest tracks a pending test request
|
||||
type pendingRequest struct {
|
||||
endpoint string
|
||||
sentAt time.Time
|
||||
replyChan chan time.Duration
|
||||
}
|
||||
|
||||
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
|
||||
return &HolepunchTester{
|
||||
sharedBind: sharedBind,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback sets the callback for connection status changes
|
||||
func (t *HolepunchTester) SetCallback(callback HolepunchStatusCallback) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.callback = callback
|
||||
}
|
||||
|
||||
// Start begins listening for magic packet responses
|
||||
func (t *HolepunchTester) Start() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.running {
|
||||
return fmt.Errorf("tester already running")
|
||||
}
|
||||
|
||||
if t.sharedBind == nil {
|
||||
return fmt.Errorf("sharedBind is nil")
|
||||
}
|
||||
|
||||
t.running = true
|
||||
t.stopChan = make(chan struct{})
|
||||
|
||||
// Register our callback with the SharedBind to receive magic responses
|
||||
t.sharedBind.SetMagicResponseCallback(t.handleResponse)
|
||||
|
||||
logger.Debug("HolepunchTester started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the tester
|
||||
func (t *HolepunchTester) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
t.running = false
|
||||
close(t.stopChan)
|
||||
|
||||
// Clear the callback
|
||||
if t.sharedBind != nil {
|
||||
t.sharedBind.SetMagicResponseCallback(nil)
|
||||
}
|
||||
|
||||
// Cancel all pending requests
|
||||
t.pendingRequests.Range(func(key, value interface{}) bool {
|
||||
if req, ok := value.(*pendingRequest); ok {
|
||||
close(req.replyChan)
|
||||
}
|
||||
t.pendingRequests.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
logger.Debug("HolepunchTester stopped")
|
||||
}
|
||||
|
||||
// handleResponse is called by SharedBind when a magic response is received
|
||||
func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) {
|
||||
key := string(echoData)
|
||||
|
||||
value, ok := t.pendingRequests.LoadAndDelete(key)
|
||||
if !ok {
|
||||
// No matching request found
|
||||
return
|
||||
}
|
||||
|
||||
req := value.(*pendingRequest)
|
||||
rtt := time.Since(req.sentAt)
|
||||
|
||||
// Send RTT to the waiting goroutine (non-blocking)
|
||||
select {
|
||||
case req.replyChan <- rtt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpoint sends a magic test packet to the endpoint and waits for a response.
|
||||
// This uses the SharedBind so packets come from the same source port as WireGuard.
|
||||
func (t *HolepunchTester) TestEndpoint(endpoint string, timeout time.Duration) TestResult {
|
||||
result := TestResult{
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
running := t.running
|
||||
sharedBind := t.sharedBind
|
||||
t.mu.RUnlock()
|
||||
|
||||
if !running {
|
||||
result.Error = fmt.Errorf("tester not running")
|
||||
return result
|
||||
}
|
||||
|
||||
if sharedBind == nil || sharedBind.IsClosed() {
|
||||
result.Error = fmt.Errorf("sharedBind is nil or closed")
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
host = net.JoinHostPort(host, "21820")
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Generate random data for the test packet
|
||||
randomData := make([]byte, bind.MagicPacketDataLen)
|
||||
if _, err := rand.Read(randomData); err != nil {
|
||||
result.Error = fmt.Errorf("failed to generate random data: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Create a pending request
|
||||
req := &pendingRequest{
|
||||
endpoint: endpoint,
|
||||
sentAt: time.Now(),
|
||||
replyChan: make(chan time.Duration, 1),
|
||||
}
|
||||
|
||||
key := string(randomData)
|
||||
t.pendingRequests.Store(key, req)
|
||||
|
||||
// Build the test request packet
|
||||
request := make([]byte, bind.MagicTestRequestLen)
|
||||
copy(request, bind.MagicTestRequest)
|
||||
copy(request[len(bind.MagicTestRequest):], randomData)
|
||||
|
||||
// Send the test packet
|
||||
_, err = sharedBind.WriteToUDP(request, remoteAddr)
|
||||
if err != nil {
|
||||
t.pendingRequests.Delete(key)
|
||||
result.Error = fmt.Errorf("failed to send test packet: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case rtt, ok := <-req.replyChan:
|
||||
if ok {
|
||||
result.Success = true
|
||||
result.RTT = rtt
|
||||
} else {
|
||||
result.Error = fmt.Errorf("request cancelled")
|
||||
}
|
||||
case <-time.After(timeout):
|
||||
t.pendingRequests.Delete(key)
|
||||
result.Error = fmt.Errorf("timeout waiting for response")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// TestConnectionWithBind sends a magic test packet using an existing SharedBind.
|
||||
// This is useful when you want to test the connection through the same socket
|
||||
// that WireGuard is using, which tests the actual hole-punched path.
|
||||
func TestConnectionWithBind(sharedBind *bind.SharedBind, endpoint string, opts *TestConnectionOptions) TestResult {
|
||||
if opts == nil {
|
||||
defaultOpts := DefaultTestOptions()
|
||||
opts = &defaultOpts
|
||||
}
|
||||
|
||||
result := TestResult{
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
|
||||
if sharedBind == nil {
|
||||
result.Error = fmt.Errorf("sharedBind is nil")
|
||||
return result
|
||||
}
|
||||
|
||||
if sharedBind.IsClosed() {
|
||||
result.Error = fmt.Errorf("sharedBind is closed")
|
||||
return result
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
host = net.JoinHostPort(host, "21820")
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", host)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to resolve UDP address %s: %w", host, err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Generate random data for the test packet
|
||||
randomData := make([]byte, bind.MagicPacketDataLen)
|
||||
if _, err := rand.Read(randomData); err != nil {
|
||||
result.Error = fmt.Errorf("failed to generate random data: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// Build the test request packet
|
||||
request := make([]byte, bind.MagicTestRequestLen)
|
||||
copy(request, bind.MagicTestRequest)
|
||||
copy(request[len(bind.MagicTestRequest):], randomData)
|
||||
|
||||
// Get the underlying UDP connection to set read deadline and read response
|
||||
udpConn := sharedBind.GetUDPConn()
|
||||
if udpConn == nil {
|
||||
result.Error = fmt.Errorf("could not get UDP connection from SharedBind")
|
||||
return result
|
||||
}
|
||||
|
||||
attempts := opts.Retries + 1
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
logger.Debug("Retrying connection test to %s (attempt %d/%d)", endpoint, attempt+1, attempts)
|
||||
}
|
||||
|
||||
// Note: We can't easily set a read deadline on the shared connection
|
||||
// without affecting WireGuard, so we use a goroutine with timeout instead
|
||||
startTime := time.Now()
|
||||
|
||||
// Send the test packet through the shared bind
|
||||
_, err = sharedBind.WriteToUDP(request, remoteAddr)
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to send test packet: %w", err)
|
||||
if attempt < attempts-1 {
|
||||
continue
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// For shared bind test, we send the packet but can't easily wait for
|
||||
// response without interfering with WireGuard's receive loop.
|
||||
// The response will be handled by SharedBind automatically.
|
||||
// We consider the test successful if the send succeeded.
|
||||
// For a full round-trip test, use TestConnection() with a separate socket.
|
||||
|
||||
result.RTT = time.Since(startTime)
|
||||
result.Success = true
|
||||
result.Error = nil
|
||||
logger.Debug("Test packet sent to %s via SharedBind", endpoint)
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user