// newt/holepunch/holepunch.go

package holepunch

import (
	"crypto/rand"
	"encoding/json"
	"fmt"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/fosrl/newt/bind"
	"github.com/fosrl/newt/logger"
	"github.com/fosrl/newt/util"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// ExitNode represents a WireGuard exit node for hole punching.
type ExitNode struct {
	Endpoint  string `json:"endpoint"`
	RelayPort uint16 `json:"relayPort"`
	PublicKey string `json:"publicKey"`
}
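
// For illustration, an ExitNode marshals to JSON of roughly this shape
// (example values only; real endpoints and keys are supplied by the server):
//
//	{"endpoint": "relay.example.com", "relayPort": 21820, "publicKey": "<base64 WireGuard key>"}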

// Manager handles UDP hole punching operations.
type Manager struct {
	mu                    sync.Mutex
	running               bool
	stopChan              chan struct{}
	sharedBind            *bind.SharedBind
	ID                    string
	token                 string
	publicKey             string
	clientType            string
	exitNodes             map[string]ExitNode // keyed by endpoint
	updateChan            chan struct{}       // signals the goroutine to refresh exit nodes
	sendHolepunchInterval time.Duration
}

const (
	sendHolepunchIntervalMax = 60 * time.Second
	sendHolepunchIntervalMin = 1 * time.Second
)
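
// Given these bounds, the send loop in runMultipleExitNodes doubles the
// interval after every tick, so the effective schedule is 1s, 2s, 4s, 8s,
// 16s, 32s, and then a steady 60s once the doubling hits the cap.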

// NewManager creates a new hole punch manager.
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
	return &Manager{
		sharedBind:            sharedBind,
		ID:                    ID,
		clientType:            clientType,
		publicKey:             publicKey,
		exitNodes:             make(map[string]ExitNode),
		sendHolepunchInterval: sendHolepunchIntervalMin,
	}
}
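
// A minimal lifecycle sketch for callers (hypothetical code: the bind setup,
// token source, and endpoint values are placeholders, not taken from this repo):
//
//	m := holepunch.NewManager(sharedBind, "newt-abc123", "newt", wgPublicKey)
//	m.SetToken(sessionToken)
//	if err := m.Start(); err != nil {
//		logger.Warn("hole punch already running: %v", err)
//	}
//	m.AddExitNode(holepunch.ExitNode{
//		Endpoint:  "relay.example.com",
//		RelayPort: 21820,
//		PublicKey: serverWGPublicKey,
//	})
//	defer m.Stop()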

// SetToken updates the authentication token used for hole punching.
func (m *Manager) SetToken(token string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.token = token
}

// IsRunning reports whether hole punching is currently active.
func (m *Manager) IsRunning() bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.running
}

// Stop stops any ongoing hole punch operations.
func (m *Manager) Stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if !m.running {
		return
	}
	if m.stopChan != nil {
		close(m.stopChan)
		m.stopChan = nil
	}
	if m.updateChan != nil {
		close(m.updateChan)
		m.updateChan = nil
	}
	m.running = false
	logger.Info("Hole punch manager stopped")
}

// AddExitNode adds a new exit node to the rotation if it doesn't already exist.
func (m *Manager) AddExitNode(exitNode ExitNode) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, exists := m.exitNodes[exitNode.Endpoint]; exists {
		logger.Debug("Exit node %s already exists in rotation", exitNode.Endpoint)
		return false
	}
	m.exitNodes[exitNode.Endpoint] = exitNode
	logger.Info("Added exit node %s to hole punch rotation", exitNode.Endpoint)
	// Signal the goroutine to refresh if running
	if m.running && m.updateChan != nil {
		select {
		case m.updateChan <- struct{}{}:
		default:
			// A refresh signal is already pending; skip
		}
	}
	return true
}

// RemoveExitNode removes an exit node from the rotation.
func (m *Manager) RemoveExitNode(endpoint string) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, exists := m.exitNodes[endpoint]; !exists {
		logger.Debug("Exit node %s not found in rotation", endpoint)
		return false
	}
	delete(m.exitNodes, endpoint)
	logger.Info("Removed exit node %s from hole punch rotation", endpoint)
	// Signal the goroutine to refresh if running
	if m.running && m.updateChan != nil {
		select {
		case m.updateChan <- struct{}{}:
		default:
			// A refresh signal is already pending; skip
		}
	}
	return true
}

// GetExitNodes returns a copy of the current exit nodes.
func (m *Manager) GetExitNodes() []ExitNode {
	m.mu.Lock()
	defer m.mu.Unlock()
	nodes := make([]ExitNode, 0, len(m.exitNodes))
	for _, node := range m.exitNodes {
		nodes = append(nodes, node)
	}
	return nodes
}

// ResetInterval resets the hole punch interval back to the minimum value,
// allowing it to climb back up through exponential backoff.
// This is useful when network conditions change or connectivity is restored.
func (m *Manager) ResetInterval() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.sendHolepunchInterval != sendHolepunchIntervalMin {
		m.sendHolepunchInterval = sendHolepunchIntervalMin
		logger.Info("Reset hole punch interval to minimum (%v)", sendHolepunchIntervalMin)
	}
	// Signal the goroutine to apply the new interval if running
	if m.running && m.updateChan != nil {
		select {
		case m.updateChan <- struct{}{}:
		default:
			// A refresh signal is already pending; skip
		}
	}
}

// TriggerHolePunch sends an immediate hole punch packet to all configured
// exit nodes. This is useful for punching on demand without waiting for the
// next tick of the send interval.
func (m *Manager) TriggerHolePunch() error {
	m.mu.Lock()
	if len(m.exitNodes) == 0 {
		m.mu.Unlock()
		return fmt.Errorf("no exit nodes configured")
	}
	// Get a copy of the exit nodes to work with outside the lock
	currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
	for _, node := range m.exitNodes {
		currentExitNodes = append(currentExitNodes, node)
	}
	m.mu.Unlock()
	logger.Info("Triggering on-demand hole punch to %d exit nodes", len(currentExitNodes))
	// Send a hole punch to every exit node, counting successes
	successCount := 0
	for _, exitNode := range currentExitNodes {
		host, err := util.ResolveDomain(exitNode.Endpoint)
		if err != nil {
			logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
			continue
		}
		serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
		remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
		if err != nil {
			logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
			continue
		}
		if err := m.sendHolePunch(remoteAddr, exitNode.PublicKey); err != nil {
			logger.Warn("Failed to send on-demand hole punch to %s: %v", exitNode.Endpoint, err)
			continue
		}
		logger.Debug("Sent on-demand hole punch to %s", exitNode.Endpoint)
		successCount++
	}
	if successCount == 0 {
		return fmt.Errorf("failed to send hole punch to any exit node")
	}
	logger.Info("Successfully sent on-demand hole punch to %d/%d exit nodes", successCount, len(currentExitNodes))
	return nil
}

// StartMultipleExitNodes starts hole punching to the given set of exit nodes,
// replacing any previously configured nodes.
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
	m.mu.Lock()
	if m.running {
		m.mu.Unlock()
		logger.Debug("UDP hole punch already running, skipping new request")
		return fmt.Errorf("hole punch already running")
	}
	// Populate the exit node map
	m.exitNodes = make(map[string]ExitNode)
	for _, node := range exitNodes {
		m.exitNodes[node.Endpoint] = node
	}
	m.running = true
	m.stopChan = make(chan struct{})
	m.updateChan = make(chan struct{}, 1)
	m.mu.Unlock()
	logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
	go m.runMultipleExitNodes()
	return nil
}

// Start starts hole punching with the current set of exit nodes. Unlike
// StartMultipleExitNodes, it keeps whatever nodes are already configured,
// and may start with none (nodes can be added later via AddExitNode).
func (m *Manager) Start() error {
	m.mu.Lock()
	if m.running {
		m.mu.Unlock()
		logger.Debug("UDP hole punch already running")
		return fmt.Errorf("hole punch already running")
	}
	m.running = true
	m.stopChan = make(chan struct{})
	m.updateChan = make(chan struct{}, 1)
	nodeCount := len(m.exitNodes)
	m.mu.Unlock()
	if nodeCount == 0 {
		logger.Info("Starting UDP hole punch manager (waiting for exit nodes to be added)")
	} else {
		logger.Info("Starting UDP hole punch with %d exit nodes", nodeCount)
	}
	go m.runMultipleExitNodes()
	return nil
}

// runMultipleExitNodes is the send loop: it resolves all exit node endpoints,
// then periodically sends hole punch packets to each of them, backing off
// exponentially between rounds.
func (m *Manager) runMultipleExitNodes() {
	defer func() {
		m.mu.Lock()
		m.running = false
		m.mu.Unlock()
		logger.Info("UDP hole punch goroutine ended for all exit nodes")
	}()
	// Capture the channels under the lock so a concurrent Stop, which nils
	// out the fields after closing them, cannot race with the select below
	m.mu.Lock()
	stopChan := m.stopChan
	updateChan := m.updateChan
	m.mu.Unlock()
	type resolvedExitNode struct {
		remoteAddr   *net.UDPAddr
		publicKey    string
		endpointName string
	}
	// Resolve the current exit nodes into UDP addresses
	resolveNodes := func() []resolvedExitNode {
		m.mu.Lock()
		currentExitNodes := make([]ExitNode, 0, len(m.exitNodes))
		for _, node := range m.exitNodes {
			currentExitNodes = append(currentExitNodes, node)
		}
		m.mu.Unlock()
		var resolvedNodes []resolvedExitNode
		for _, exitNode := range currentExitNodes {
			host, err := util.ResolveDomain(exitNode.Endpoint)
			if err != nil {
				logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
				continue
			}
			serverAddr := net.JoinHostPort(host, strconv.Itoa(int(exitNode.RelayPort)))
			remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
			if err != nil {
				logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
				continue
			}
			resolvedNodes = append(resolvedNodes, resolvedExitNode{
				remoteAddr:   remoteAddr,
				publicKey:    exitNode.PublicKey,
				endpointName: exitNode.Endpoint,
			})
			logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
		}
		return resolvedNodes
	}
	resolvedNodes := resolveNodes()
	if len(resolvedNodes) == 0 {
		logger.Info("No exit nodes available yet, waiting for nodes to be added")
	} else {
		// Send an initial hole punch to all exit nodes
		for _, node := range resolvedNodes {
			if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
				logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
			}
		}
	}
	// Start with the minimum interval
	m.mu.Lock()
	m.sendHolepunchInterval = sendHolepunchIntervalMin
	m.mu.Unlock()
	ticker := time.NewTicker(sendHolepunchIntervalMin)
	defer ticker.Stop()
	for {
		select {
		case <-stopChan:
			logger.Debug("Hole punch stopped by signal")
			return
		case <-updateChan:
			// Re-resolve exit nodes when an update is signaled
			logger.Info("Refreshing exit nodes for hole punching")
			resolvedNodes = resolveNodes()
			if len(resolvedNodes) == 0 {
				logger.Warn("No exit nodes available after refresh")
			} else {
				logger.Info("Updated resolved nodes count: %d", len(resolvedNodes))
			}
			// Reset the interval to the minimum on update
			m.mu.Lock()
			m.sendHolepunchInterval = sendHolepunchIntervalMin
			m.mu.Unlock()
			ticker.Reset(sendHolepunchIntervalMin)
			// Send an immediate hole punch to the newly resolved nodes
			for _, node := range resolvedNodes {
				if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
					logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
				}
			}
		case <-ticker.C:
			// Send a hole punch to all exit nodes (if any are available)
			if len(resolvedNodes) > 0 {
				for _, node := range resolvedNodes {
					if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
						logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
					}
				}
				// Exponential backoff: double the interval up to the max;
				// the reset happens under the lock so the field and the
				// ticker stay in sync
				m.mu.Lock()
				newInterval := m.sendHolepunchInterval * 2
				if newInterval > sendHolepunchIntervalMax {
					newInterval = sendHolepunchIntervalMax
				}
				if newInterval != m.sendHolepunchInterval {
					m.sendHolepunchInterval = newInterval
					ticker.Reset(newInterval)
					logger.Debug("Increased hole punch interval to %v", newInterval)
				}
				m.mu.Unlock()
			}
		}
	}
}

// sendHolePunch sends an encrypted hole punch packet using the shared bind.
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
	m.mu.Lock()
	token := m.token
	ID := m.ID
	m.mu.Unlock()
	if serverPubKey == "" || token == "" {
		return fmt.Errorf("server public key or token is empty")
	}
	// The payload is identical for both client types except for the JSON
	// key that carries the ID ("newtId" vs "olmId")
	var payload interface{}
	if m.clientType == "newt" {
		payload = struct {
			ID        string `json:"newtId"`
			Token     string `json:"token"`
			PublicKey string `json:"publicKey"`
		}{
			ID:        ID,
			Token:     token,
			PublicKey: m.publicKey,
		}
	} else {
		payload = struct {
			ID        string `json:"olmId"`
			Token     string `json:"token"`
			PublicKey string `json:"publicKey"`
		}{
			ID:        ID,
			Token:     token,
			PublicKey: m.publicKey,
		}
	}
	// Convert the payload to JSON
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal payload: %w", err)
	}
	// Encrypt the payload using the server's WireGuard public key
	encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
	if err != nil {
		return fmt.Errorf("failed to encrypt payload: %w", err)
	}
	jsonData, err := json.Marshal(encryptedPayload)
	if err != nil {
		return fmt.Errorf("failed to marshal encrypted payload: %w", err)
	}
	if _, err := m.sharedBind.WriteToUDP(jsonData, remoteAddr); err != nil {
		return fmt.Errorf("failed to write to UDP: %w", err)
	}
	logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
	return nil
}
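
// Note on the wire format: encoding/json encodes []byte fields as base64
// strings, so the datagram written above is a single JSON object shaped like
// this (placeholder values; field names come from encryptPayload below):
//
//	{"ephemeralPublicKey":"<base64 X25519 key>","nonce":"<base64>","ciphertext":"<base64>"}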

// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD keyed by
// an X25519 exchange between a fresh ephemeral key and the server's
// WireGuard public key.
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
	// Generate an ephemeral keypair for this message
	ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
	}
	ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
	// Parse the server's public key
	serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse server public key: %v", err)
	}
	// Perform the X25519 key exchange; the 32-byte shared secret doubles as
	// the AEAD key
	sharedSecret, err := curve25519.X25519(ephemeralPrivateKey[:], serverPubKey[:])
	if err != nil {
		return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
	}
	// Create an AEAD cipher using the shared secret
	aead, err := chacha20poly1305.New(sharedSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
	}
	// Generate a random nonce from crypto/rand; a math-style PRNG must not
	// be used here, since AEAD nonces have to be unpredictable
	nonce := make([]byte, aead.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %v", err)
	}
	// Encrypt the payload
	ciphertext := aead.Seal(nil, nonce, payload, nil)
	// Prepare the final encrypted message
	encryptedMsg := struct {
		EphemeralPublicKey string `json:"ephemeralPublicKey"`
		Nonce              []byte `json:"nonce"`
		Ciphertext         []byte `json:"ciphertext"`
	}{
		EphemeralPublicKey: ephemeralPublicKey.String(),
		Nonce:              nonce,
		Ciphertext:         ciphertext,
	}
	return encryptedMsg, nil
}
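
// For reference, a minimal sketch of the matching decryption on the receiving
// side (hypothetical code: the real server implementation lives outside this
// file, and serverPrivateKey is assumed to be the server's WireGuard private
// key):
//
//	func decryptPayload(msg struct {
//		EphemeralPublicKey string `json:"ephemeralPublicKey"`
//		Nonce              []byte `json:"nonce"`
//		Ciphertext         []byte `json:"ciphertext"`
//	}, serverPrivateKey wgtypes.Key) ([]byte, error) {
//		ephPubKey, err := wgtypes.ParseKey(msg.EphemeralPublicKey)
//		if err != nil {
//			return nil, err
//		}
//		// Same X25519 exchange as encryptPayload, seen from the server's side
//		sharedSecret, err := curve25519.X25519(serverPrivateKey[:], ephPubKey[:])
//		if err != nil {
//			return nil, err
//		}
//		aead, err := chacha20poly1305.New(sharedSecret)
//		if err != nil {
//			return nil, err
//		}
//		// Open authenticates and decrypts in one step
//		return aead.Open(nil, msg.Nonce, msg.Ciphertext, nil)
//	}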