Split out hp

Former-commit-id: 29ed4fefbf
This commit is contained in:
Owen
2025-11-07 21:59:07 -08:00
parent a61c7ca1ee
commit 78e3bb374a
4 changed files with 402 additions and 504 deletions

351
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,351 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DomainResolver is a function type for resolving domains to IP addresses
type DomainResolver func(string) (string, error)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
olmID string
token string
domainResolver DomainResolver
}
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
return &Manager{
sharedBind: sharedBind,
olmID: olmID,
domainResolver: domainResolver,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes(exitNodes)
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := m.domainResolver(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := m.domainResolver(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
olmID := m.olmID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

View File

@@ -1 +1 @@
767662d6fa777b3bb77d47a1c44eb5fb60249e87
573df1772c00fcb34ec68e575e973c460dc27ba8

View File

@@ -3,7 +3,6 @@ package olm
import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os/exec"
@@ -14,12 +13,9 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -192,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
}
}
func resolveDomain(domain string) (string, error) {
func ResolveDomain(domain string) (string, error) {
// First handle any protocol prefix
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
@@ -239,463 +235,6 @@ func resolveDomain(domain string) (string, error) {
return ipAddr, nil
}
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
if serverPubKey == "" || olmToken == "" {
return nil
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: olmToken,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
}
_, err = conn.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
if len(exitNodes) == 0 {
logger.Warn("No exit nodes provided for hole punching")
return
}
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
// Create the UDP connection once and reuse it for all exit nodes
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := resolveDomain(exitNode.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch for all exit nodes")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %s", endpoint)
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
// Create the UDP connection once and reuse it
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address: %v", err)
return
}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to bind UDP socket: %v", err)
return
}
defer conn.Close()
// Execute once immediately before starting the loop
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
if len(exitNodes) == 0 {
logger.Warn("No exit nodes provided for hole punching")
return
}
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := resolveDomain(exitNode.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch for all exit nodes")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send initial UDP hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
if serverPubKey == "" || olmToken == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: olmToken,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
@@ -772,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) {
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
siteHost, err := resolveDomain(siteConfig.Endpoint)
siteHost, err := ResolveDomain(siteConfig.Endpoint)
if err != nil {
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
}
@@ -829,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/holepunch"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
@@ -69,6 +70,7 @@ var (
tunnelCancel context.CancelFunc
tunnelRunning bool
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
)
func Run(ctx context.Context, config Config) {
@@ -197,7 +199,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
}()
// Recreate channels for this tunnel session
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
var (
@@ -260,6 +261,11 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
}
// Create the holepunch manager
if holePunchManager == nil {
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
@@ -274,12 +280,20 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
return
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
for i, node := range holePunchData.ExitNodes {
exitNodes[i] = holepunch.ExitNode{
Endpoint: node.Endpoint,
PublicKey: node.PublicKey,
}
}
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
// Start hole punching using the manager
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
@@ -304,20 +318,16 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
return
}
// Stop any existing hole punch goroutines by closing the current channel
select {
case <-stopHolepunch:
// Channel already closed
default:
close(stopHolepunch)
// Stop any existing hole punch operations
if holePunchManager != nil {
holePunchManager.Stop()
}
// Create a new stopHolepunch channel for the new set of goroutines
stopHolepunch = make(chan struct{})
// Start hole punching for each exit node
// Start hole punching for the exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
@@ -407,6 +417,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
for {
conn, err := uapiListener.Accept()
if err != nil {
return
}
go dev.IpcHandle(conn)
@@ -696,7 +707,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
return
}
primaryRelay, err := resolveDomain(relayData.Endpoint)
primaryRelay, err := ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
@@ -752,7 +763,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
})
olm.OnTokenUpdate(func(token string) {
olmToken = token
if holePunchManager != nil {
holePunchManager.SetToken(token)
}
})
// Connect to the WebSocket server
@@ -780,7 +793,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
stopHolepunch = make(chan struct{})
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
@@ -799,13 +811,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
}
func Stop() {
if stopHolepunch != nil {
select {
case <-stopHolepunch:
// Channel already closed, do nothing
default:
close(stopHolepunch)
}
// Stop hole punch manager
if holePunchManager != nil {
holePunchManager.Stop()
}
if stopPing != nil {