From 78e3bb374a3905a0d6e46b00801262318d3e5b1e Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 7 Nov 2025 21:59:07 -0800 Subject: [PATCH] Split out hp Former-commit-id: 29ed4fefbf32fe6263f0e93d236cc51c6e39c050 --- holepunch/holepunch.go | 351 ++++++++++++++++++++++++++++ olm-binary.REMOVED.git-id | 2 +- olm/common.go | 467 +------------------------------------- olm/olm.go | 86 +++---- 4 files changed, 402 insertions(+), 504 deletions(-) create mode 100644 holepunch/holepunch.go diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go new file mode 100644 index 0000000..187d3fe --- /dev/null +++ b/holepunch/holepunch.go @@ -0,0 +1,351 @@ +package holepunch + +import ( + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/bind" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// DomainResolver is a function type for resolving domains to IP addresses +type DomainResolver func(string) (string, error) + +// ExitNode represents a WireGuard exit node for hole punching +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + +// Manager handles UDP hole punching operations +type Manager struct { + mu sync.Mutex + running bool + stopChan chan struct{} + sharedBind *bind.SharedBind + olmID string + token string + domainResolver DomainResolver +} + +// NewManager creates a new hole punch manager +func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager { + return &Manager{ + sharedBind: sharedBind, + olmID: olmID, + domainResolver: domainResolver, + } +} + +// SetToken updates the authentication token used for hole punching +func (m *Manager) SetToken(token string) { + m.mu.Lock() + defer m.mu.Unlock() + m.token = token +} + +// IsRunning returns whether hole punching is currently active +func (m *Manager) IsRunning() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.running +} + +// Stop stops any ongoing hole punch operations +func (m *Manager) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.running { + return + } + + if m.stopChan != nil { + close(m.stopChan) + m.stopChan = nil + } + + m.running = false + logger.Info("Hole punch manager stopped") +} + +// StartMultipleExitNodes starts hole punching to multiple exit nodes +func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running, skipping new request") + return fmt.Errorf("hole punch already running") + } + + if len(exitNodes) == 0 { + m.mu.Unlock() + logger.Warn("No exit nodes provided for hole punching") + return fmt.Errorf("no exit nodes provided") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) + + go m.runMultipleExitNodes(exitNodes) + + return nil +} + +// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode) +func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running, skipping new request") + return fmt.Errorf("hole punch already running") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch to %s with shared bind", endpoint) + + go m.runSingleEndpoint(endpoint, serverPubKey) + + return nil +} + +// runMultipleExitNodes performs hole punching to multiple exit nodes +func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { + defer func() { + m.mu.Lock() + m.running = false + m.mu.Unlock() + logger.Info("UDP hole punch goroutine ended for all exit nodes") + }() + + // Resolve all endpoints upfront + type resolvedExitNode struct { + remoteAddr *net.UDPAddr + publicKey string + endpointName string + } + + var resolvedNodes []resolvedExitNode + for _, exitNode := range exitNodes { + host, err := m.domainResolver(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + } + + if len(resolvedNodes) == 0 { + logger.Error("No exit nodes could be resolved") + return + } + + // Send initial hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err) + } + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-m.stopChan: + logger.Debug("Hole punch stopped by signal") + return + case <-timeout.C: + logger.Debug("Hole punch timeout reached") + return + case <-ticker.C: + // Send hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) + } + } + } + } +} + +// runSingleEndpoint performs hole punching to a single endpoint +func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) { + defer func() { + m.mu.Lock() + m.running = false + m.mu.Unlock() + logger.Info("UDP hole punch goroutine ended for %s", endpoint) + }() + + host, err := m.domainResolver(endpoint) + if err != nil { + logger.Error("Failed to resolve domain %s: %v", endpoint, err) + return + } + + serverAddr := net.JoinHostPort(host, "21820") + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + return + } + + // Execute once immediately before starting the loop + if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { + logger.Warn("Failed to send initial hole punch: %v", err) + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-m.stopChan: + logger.Debug("Hole punch stopped by signal") + return + case <-timeout.C: + logger.Debug("Hole punch timeout reached") + return + case <-ticker.C: + if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { + logger.Debug("Failed to send hole punch: %v", err) + } + } + } +} + +// sendHolePunch sends an encrypted hole punch packet using the shared bind +func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { + m.mu.Lock() + token := m.token + olmID := m.olmID + m.mu.Unlock() + + if serverPubKey == "" || token == "" { + return fmt.Errorf("server public key or OLM token is empty") + } + + payload := struct { + OlmID string `json:"olmId"` + Token string `json:"token"` + }{ + OlmID: olmID, + Token: token, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %w", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %w", err) + } + + _, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr) + if err != nil { + return fmt.Errorf("failed to write to UDP: %w", err) + } + + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + + return nil +} + +// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange +func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(serverPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} diff --git a/olm-binary.REMOVED.git-id b/olm-binary.REMOVED.git-id index 78de5d4..830c71f 100644 --- a/olm-binary.REMOVED.git-id +++ b/olm-binary.REMOVED.git-id @@ -1 +1 @@ -767662d6fa777b3bb77d47a1c44eb5fb60249e87 \ No newline at end of file +573df1772c00fcb34ec68e575e973c460dc27ba8 \ No newline at end of file diff --git a/olm/common.go b/olm/common.go index f082a6a..c15b66d 100644 --- a/olm/common.go +++ b/olm/common.go @@ -3,7 +3,6 @@ package olm import ( "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "net" "os/exec" @@ -14,12 +13,9 @@ import ( "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/bind" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "github.com/vishvananda/netlink" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" "golang.org/x/exp/rand" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -192,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int { } } -func resolveDomain(domain string) (string, error) { +func ResolveDomain(domain string) (string, error) { // First handle any protocol prefix domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://") @@ -239,463 +235,6 @@ func resolveDomain(domain string) (string, error) { return ipAddr, nil } -func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error { - if serverPubKey == "" || olmToken == "" { - return nil - } - - payload := struct { - OlmID string `json:"olmId"` - Token string `json:"token"` - }{ - OlmID: olmID, - Token: olmToken, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %v", err) - } - - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %v", err) - } - - _, err = conn.WriteToUDP(jsonData, remoteAddr) - if err != nil { - return fmt.Errorf("failed to send UDP packet: %v", err) - } - - logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) - - return nil -} - -func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(serverPublicKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange (replacing deprecated ScalarMult) - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} - -func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) { - if len(exitNodes) == 0 { - logger.Warn("No exit nodes provided for hole punching") - return - } - - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes)) - defer logger.Info("UDP hole punch goroutine ended for all exit nodes") - - // Create the UDP connection once and reuse it for all exit nodes - localAddr := &net.UDPAddr{ - Port: int(sourcePort), - IP: net.IPv4zero, - } - - conn, err := net.ListenUDP("udp", localAddr) - if err != nil { - logger.Error("Failed to bind UDP socket: %v", err) - return - } - defer conn.Close() - - // Resolve all endpoints upfront - type resolvedExitNode struct { - remoteAddr *net.UDPAddr - publicKey string - endpointName string - } - - var resolvedNodes []resolvedExitNode - for _, exitNode := range exitNodes { - host, err := resolveDomain(exitNode.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) - continue - } - - serverAddr := net.JoinHostPort(host, "21820") - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err) - continue - } - - resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, - }) - logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) - } - - if len(resolvedNodes) == 0 { - logger.Error("No exit nodes could be resolved") - return - } - - // Send initial hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err) - } - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch for all exit nodes") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes") - return - case <-ticker.C: - // Send hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err) - } - } - } - } -} - -func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) { - - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %s", endpoint) - defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) - - host, err := resolveDomain(endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } - - serverAddr := net.JoinHostPort(host, "21820") - - // Create the UDP connection once and reuse it - localAddr := &net.UDPAddr{ - Port: int(sourcePort), - IP: net.IPv4zero, - } - - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address: %v", err) - return - } - - conn, err := net.ListenUDP("udp", localAddr) - if err != nil { - logger.Error("Failed to bind UDP socket: %v", err) - return - } - defer conn.Close() - - // Execute once immediately before starting the loop - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - } - } -} - -// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind -func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) { - if len(exitNodes) == 0 { - logger.Warn("No exit nodes provided for hole punching") - return - } - - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) - defer logger.Info("UDP hole punch goroutine ended for all exit nodes") - - // Resolve all endpoints upfront - type resolvedExitNode struct { - remoteAddr *net.UDPAddr - publicKey string - endpointName string - } - - var resolvedNodes []resolvedExitNode - for _, exitNode := range exitNodes { - host, err := resolveDomain(exitNode.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) - continue - } - - serverAddr := net.JoinHostPort(host, "21820") - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err) - continue - } - - resolvedNodes = append(resolvedNodes, resolvedExitNode{ - remoteAddr: remoteAddr, - publicKey: exitNode.PublicKey, - endpointName: exitNode.Endpoint, - }) - logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) - } - - if len(resolvedNodes) == 0 { - logger.Error("No exit nodes could be resolved") - return - } - - // Send initial hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err) - } - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch for all exit nodes") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes") - return - case <-ticker.C: - // Send hole punch to all exit nodes - for _, node := range resolvedNodes { - if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil { - logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err) - } - } - } - } -} - -// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind -func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) { - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() - - logger.Info("Starting UDP hole punch to %s with shared bind", endpoint) - defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) - - host, err := resolveDomain(endpoint) - if err != nil { - logger.Error("Failed to resolve domain %s: %v", endpoint, err) - return - } - - serverAddr := net.JoinHostPort(host, "21820") - - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) - if err != nil { - logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) - return - } - - // Execute once immediately before starting the loop - if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send initial UDP hole punch: %v", err) - } - - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - } - } -} - -// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind -func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error { - if serverPubKey == "" || olmToken == "" { - return fmt.Errorf("server public key or OLM token is empty") - } - - payload := struct { - OlmID string `json:"olmId"` - Token string `json:"token"` - }{ - OlmID: olmID, - Token: olmToken, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %w", err) - } - - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %w", err) - } - - _, err = sharedBind.WriteToUDP(jsonData, remoteAddr) - if err != nil { - return fmt.Errorf("failed to write to UDP: %w", err) - } - - logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) - - return nil -} - func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) @@ -772,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) { // ConfigurePeer sets up or updates a peer within the WireGuard device func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { - siteHost, err := resolveDomain(siteConfig.Endpoint) + siteHost, err := ResolveDomain(siteConfig.Endpoint) if err != nil { return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) } @@ -829,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) - primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable + primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } diff --git a/olm/olm.go b/olm/olm.go index 7821a32..211b90b 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -13,6 +13,7 @@ import ( "github.com/fosrl/newt/updates" "github.com/fosrl/olm/api" "github.com/fosrl/olm/bind" + "github.com/fosrl/olm/holepunch" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -57,18 +58,19 @@ type Config struct { } var ( - privateKey wgtypes.Key - connected bool - dev *device.Device - wgData WgData - holePunchData HolePunchData - uapiListener net.Listener - tdev tun.Device - apiServer *api.API - olmClient *websocket.Client - tunnelCancel context.CancelFunc - tunnelRunning bool - sharedBind *bind.SharedBind + privateKey wgtypes.Key + connected bool + dev *device.Device + wgData WgData + holePunchData HolePunchData + uapiListener net.Listener + tdev tun.Device + apiServer *api.API + olmClient *websocket.Client + tunnelCancel context.CancelFunc + tunnelRunning bool + sharedBind *bind.SharedBind + holePunchManager *holepunch.Manager ) func Run(ctx context.Context, config Config) { @@ -197,7 +199,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, }() // Recreate channels for this tunnel session - stopHolepunch = make(chan struct{}) stopPing = make(chan struct{}) var ( @@ -260,6 +261,11 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) } + // Create the holepunch manager + if holePunchManager == nil { + holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain) + } + olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -274,12 +280,20 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - // Create a new stopHolepunch channel for the new set of goroutines - stopHolepunch = make(chan struct{}) + // Convert HolePunchData.ExitNodes to holepunch.ExitNode slice + exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes)) + for i, node := range holePunchData.ExitNodes { + exitNodes[i] = holepunch.ExitNode{ + Endpoint: node.Endpoint, + PublicKey: node.PublicKey, + } + } - // Start a single hole punch goroutine for all exit nodes - logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes)) - go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind) + // Start hole punching using the manager + logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) + if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } }) olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { @@ -304,20 +318,16 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - // Stop any existing hole punch goroutines by closing the current channel - select { - case <-stopHolepunch: - // Channel already closed - default: - close(stopHolepunch) + // Stop any existing hole punch operations + if holePunchManager != nil { + holePunchManager.Stop() } - // Create a new stopHolepunch channel for the new set of goroutines - stopHolepunch = make(chan struct{}) - - // Start hole punching for each exit node + // Start hole punching for the exit node logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) - go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey) + if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } }) olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { @@ -407,6 +417,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, for { conn, err := uapiListener.Accept() if err != nil { + return } go dev.IpcHandle(conn) @@ -696,7 +707,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - primaryRelay, err := resolveDomain(relayData.Endpoint) + primaryRelay, err := ResolveDomain(relayData.Endpoint) if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) } @@ -752,7 +763,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, }) olm.OnTokenUpdate(func(token string) { - olmToken = token + if holePunchManager != nil { + holePunchManager.SetToken(token) + } }) // Connect to the WebSocket server @@ -780,7 +793,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, apiServer.SetTunnelIP("") apiServer.SetOrgID(config.OrgID) - stopHolepunch = make(chan struct{}) // Trigger re-registration with new orgId logger.Info("Re-registering with new orgId: %s", config.OrgID) publicKey := privateKey.PublicKey() @@ -799,13 +811,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, } func Stop() { - if stopHolepunch != nil { - select { - case <-stopHolepunch: - // Channel already closed, do nothing - default: - close(stopHolepunch) - } + // Stop hole punch manager + if holePunchManager != nil { + holePunchManager.Stop() } if stopPing != nil {