From 01ec6a0ce0fbd6d51f50659e40fa7e2b3b030eca Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 13:54:14 -0500 Subject: [PATCH] Handle holepunches better --- clients/clients.go | 2 +- holepunch/holepunch.go | 303 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 268 insertions(+), 37 deletions(-) diff --git a/clients/clients.go b/clients/clients.go index 68fb780..4b4f2b5 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -184,7 +184,7 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str // Create the holepunch manager with ResolveDomain function // We'll need to pass a domain resolver function - service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt") + service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String()) // Register websocket handlers wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 4c09906..41d3846 100644 --- a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -30,20 +30,29 @@ type Manager struct { sharedBind *bind.SharedBind ID string token string + publicKey string clientType string + exitNodes map[string]ExitNode // key is endpoint + updateChan chan struct{} // signals the goroutine to refresh exit nodes + + sendHolepunchInterval time.Duration } +const sendHolepunchIntervalMax = 60 * time.Second +const sendHolepunchIntervalMin = 1 * time.Second + // NewManager creates a new hole punch manager -func NewManager(sharedBind *bind.SharedBind, ID string, clientType string) *Manager { +func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager { return &Manager{ - sharedBind: sharedBind, - ID: ID, - clientType: clientType, + sharedBind: sharedBind, + ID: ID, + clientType: clientType, + publicKey: publicKey, + exitNodes: make(map[string]ExitNode), + sendHolepunchInterval: sendHolepunchIntervalMin, } } -const sendHolepunchInterval = 15 * time.Second - // SetToken updates the authentication token used for hole punching func (m *Manager) SetToken(token string) { m.mu.Lock() @@ -72,10 +81,129 @@ func (m *Manager) Stop() { m.stopChan = nil } + if m.updateChan != nil { + close(m.updateChan) + m.updateChan = nil + } + m.running = false logger.Info("Hole punch manager stopped") } +// AddExitNode adds a new exit node to the rotation if it doesn't already exist +func (m *Manager) AddExitNode(exitNode ExitNode) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.exitNodes[exitNode.Endpoint]; exists { + logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint) + return false + } + + m.exitNodes[exitNode.Endpoint] = exitNode + logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint) + + // Signal the goroutine to refresh if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } + + return true +} + +// RemoveExitNode removes an exit node from the rotation +func (m *Manager) RemoveExitNode(endpoint string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.exitNodes[endpoint]; !exists { + logger.Debug("Exit node %s not found in rotation", endpoint) + return false + } + + delete(m.exitNodes, endpoint) + logger.Info("Removed exit node %s from hole punch rotation", endpoint) + + // Signal the goroutine to refresh if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } + + return true +} + +// GetExitNodes returns a copy of the current exit nodes +func (m *Manager) GetExitNodes() []ExitNode { + m.mu.Lock() + defer m.mu.Unlock() + + nodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + nodes = append(nodes, node) + } + return nodes +} + +// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes +// This is useful for triggering hole punching on demand without waiting for the interval +func (m *Manager) TriggerHolePunch() error { + m.mu.Lock() + + if len(m.exitNodes) == 0 { + m.mu.Unlock() + return fmt.Errorf("no exit nodes configured") + } + + // Get a copy of exit nodes to work with + currentExitNodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + currentExitNodes = append(currentExitNodes, node) + } + m.mu.Unlock() + + logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes)) + + // Send hole punch to all exit nodes + successCount := 0 + for _, exitNode := range currentExitNodes { + host, err := util.ResolveDomain(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil { + logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err) + continue + } + + logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint) + successCount++ + } + + if successCount == 0 { + return fmt.Errorf("failed to send hole punch to any exit node") + } + + logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes)) + return nil +} + // StartMultipleExitNodes starts hole punching to multiple exit nodes func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { m.mu.Lock() @@ -92,13 +220,48 @@ func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { return fmt.Errorf("no exit nodes provided") } + // Populate exit nodes map + m.exitNodes = make(map[string]ExitNode) + for _, node := range exitNodes { + m.exitNodes[node.Endpoint] = node + } + m.running = true m.stopChan = make(chan struct{}) + m.updateChan = make(chan struct{}, 1) m.mu.Unlock() logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) - go m.runMultipleExitNodes(exitNodes) + go m.runMultipleExitNodes() + + return nil +} + +// Start starts hole punching with the current set of exit nodes +func (m *Manager) Start() error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running") + return fmt.Errorf("hole punch already running") + } + + if len(m.exitNodes) == 0 { + m.mu.Unlock() + logger.Warn("No exit nodes configured for hole punching") + return fmt.Errorf("no exit nodes configured") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.updateChan = make(chan struct{}, 1) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch with %d exit nodes", len(m.exitNodes)) + + go m.runMultipleExitNodes() return nil } @@ -125,7 +288,7 @@ func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error { } // runMultipleExitNodes performs hole punching to multiple exit nodes -func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { +func (m *Manager) runMultipleExitNodes() { defer func() { m.mu.Lock() m.running = false @@ -140,29 +303,41 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { endpointName string } - var resolvedNodes []resolvedExitNode - for _, exitNode := range exitNodes { - host, err := util.ResolveDomain(exitNode.Endpoint) - if err != nil { - logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) - continue + resolveNodes := func() []resolvedExitNode { + m.mu.Lock() + currentExitNodes := make([]ExitNode, 0, len(m.exitNodes)) + for _, node := range m.exitNodes { + currentExitNodes = append(currentExitNodes, node) } + m.mu.Unlock() - serverAddr := net.JoinHostPort(host, "21820") - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) - continue + var resolvedNodes []resolvedExitNode + for _, exitNode := range currentExitNodes { + host, err := util.ResolveDomain(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) } - - resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, - }) - logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + return resolvedNodes } + resolvedNodes := resolveNodes() + if len(resolvedNodes) == 0 { logger.Error("No exit nodes could be resolved") return @@ -175,7 +350,12 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { } } - ticker := time.NewTicker(sendHolepunchInterval) + // Start with minimum interval + m.mu.Lock() + m.sendHolepunchInterval = sendHolepunchIntervalMin + m.mu.Unlock() + + ticker := time.NewTicker(m.sendHolepunchInterval) defer ticker.Stop() for { @@ -183,6 +363,24 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { case <-m.stopChan: logger.Debug("Hole punch stopped by signal") return + case <-m.updateChan: + // Re-resolve exit nodes when update is signaled + logger.Info("Refreshing exit nodes for hole punching") + resolvedNodes = resolveNodes() + if len(resolvedNodes) == 0 { + logger.Warn("No exit nodes available after refresh") + } + // Reset interval to minimum on update + m.mu.Lock() + m.sendHolepunchInterval = sendHolepunchIntervalMin + m.mu.Unlock() + ticker.Reset(m.sendHolepunchInterval) + // Send immediate hole punch to newly resolved nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) + } + } case <-ticker.C: // Send hole punch to all exit nodes for _, node := range resolvedNodes { @@ -190,6 +388,18 @@ func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) } } + // Exponential backoff: double the interval up to max + m.mu.Lock() + newInterval := m.sendHolepunchInterval * 2 + if newInterval > sendHolepunchIntervalMax { + newInterval = sendHolepunchIntervalMax + } + if newInterval != m.sendHolepunchInterval { + m.sendHolepunchInterval = newInterval + ticker.Reset(m.sendHolepunchInterval) + logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval) + } + m.mu.Unlock() } } } @@ -222,7 +432,12 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) { logger.Warn("Failed to send initial hole punch: %v", err) } - ticker := time.NewTicker(sendHolepunchInterval) + // Start with minimum interval + m.mu.Lock() + m.sendHolepunchInterval = sendHolepunchIntervalMin + m.mu.Unlock() + + ticker := time.NewTicker(m.sendHolepunchInterval) defer ticker.Stop() for { @@ -234,6 +449,18 @@ func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) { if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { logger.Debug("Failed to send hole punch: %v", err) } + // Exponential backoff: double the interval up to max + m.mu.Lock() + newInterval := m.sendHolepunchInterval * 2 + if newInterval > sendHolepunchIntervalMax { + newInterval = sendHolepunchIntervalMax + } + if newInterval != m.sendHolepunchInterval { + m.sendHolepunchInterval = newInterval + ticker.Reset(m.sendHolepunchInterval) + logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval) + } + m.mu.Unlock() } } } @@ -252,19 +479,23 @@ func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) er var payload interface{} if m.clientType == "newt" { payload = struct { - ID string `json:"newtId"` - Token string `json:"token"` + ID string `json:"newtId"` + Token string `json:"token"` + PublicKey string `json:"publicKey"` }{ - ID: ID, - Token: token, + ID: ID, + Token: token, + PublicKey: m.publicKey, } } else { payload = struct { - ID string `json:"olmId"` - Token string `json:"token"` + ID string `json:"olmId"` + Token string `json:"token"` + PublicKey string `json:"publicKey"` }{ - ID: ID, - Token: token, + ID: ID, + Token: token, + PublicKey: m.publicKey, } }