mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Split out hp
This commit is contained in:
351
holepunch/holepunch.go
Normal file
351
holepunch/holepunch.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// DomainResolver is a function type for resolving domains to IP addresses
|
||||
type DomainResolver func(string) (string, error)
|
||||
|
||||
// ExitNode represents a WireGuard exit node for hole punching
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
olmID string
|
||||
token string
|
||||
domainResolver DomainResolver
|
||||
}
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
olmID: olmID,
|
||||
domainResolver: domainResolver,
|
||||
}
|
||||
}
|
||||
|
||||
// SetToken updates the authentication token used for hole punching
|
||||
func (m *Manager) SetToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.token = token
|
||||
}
|
||||
|
||||
// IsRunning returns whether hole punching is currently active
|
||||
func (m *Manager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.running
|
||||
}
|
||||
|
||||
// Stop stops any ongoing hole punch operations
|
||||
func (m *Manager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
|
||||
if m.stopChan != nil {
|
||||
close(m.stopChan)
|
||||
m.stopChan = nil
|
||||
}
|
||||
|
||||
m.running = false
|
||||
logger.Info("Hole punch manager stopped")
|
||||
}
|
||||
|
||||
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
if len(exitNodes) == 0 {
|
||||
m.mu.Unlock()
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return fmt.Errorf("no exit nodes provided")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
|
||||
go m.runMultipleExitNodes(exitNodes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||
|
||||
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
}()
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := m.domainResolver(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runSingleEndpoint performs hole punching to a single endpoint
|
||||
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
}()
|
||||
|
||||
host, err := m.domainResolver(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
token := m.token
|
||||
olmID := m.olmID
|
||||
m.mu.Unlock()
|
||||
|
||||
if serverPubKey == "" || token == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
OlmID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
OlmID: olmID,
|
||||
Token: token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
BIN
olm-binary
BIN
olm-binary
Binary file not shown.
467
olm/common.go
467
olm/common.go
@@ -3,7 +3,6 @@ package olm
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
@@ -14,12 +13,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -192,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDomain(domain string) (string, error) {
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// First handle any protocol prefix
|
||||
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
||||
|
||||
@@ -239,463 +235,6 @@ func resolveDomain(domain string) (string, error) {
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
||||
if serverPubKey == "" || olmToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
OlmID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
OlmID: olmID,
|
||||
Token: olmToken,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
|
||||
if len(exitNodes) == 0 {
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
|
||||
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
|
||||
// Create the UDP connection once and reuse it for all exit nodes
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(sourcePort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to bind UDP socket: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := resolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
|
||||
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s", endpoint)
|
||||
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
|
||||
host, err := resolveDomain(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
// Create the UDP connection once and reuse it
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(sourcePort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to bind UDP socket: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
|
||||
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
|
||||
if len(exitNodes) == 0 {
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := resolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
|
||||
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
|
||||
host, err := resolveDomain(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send initial UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
|
||||
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
||||
if serverPubKey == "" || olmToken == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
OlmID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
OlmID: olmID,
|
||||
Token: olmToken,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
@@ -772,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) {
|
||||
|
||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||
siteHost, err := resolveDomain(siteConfig.Endpoint)
|
||||
siteHost, err := ResolveDomain(siteConfig.Endpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||
}
|
||||
@@ -829,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
|
||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||
|
||||
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
|
||||
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
|
||||
86
olm/olm.go
86
olm/olm.go
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/fosrl/newt/updates"
|
||||
"github.com/fosrl/olm/api"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"github.com/fosrl/olm/holepunch"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@@ -57,18 +58,19 @@ type Config struct {
|
||||
}
|
||||
|
||||
var (
|
||||
privateKey wgtypes.Key
|
||||
connected bool
|
||||
dev *device.Device
|
||||
wgData WgData
|
||||
holePunchData HolePunchData
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
tunnelCancel context.CancelFunc
|
||||
tunnelRunning bool
|
||||
sharedBind *bind.SharedBind
|
||||
privateKey wgtypes.Key
|
||||
connected bool
|
||||
dev *device.Device
|
||||
wgData WgData
|
||||
holePunchData HolePunchData
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
tunnelCancel context.CancelFunc
|
||||
tunnelRunning bool
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, config Config) {
|
||||
@@ -197,7 +199,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
}()
|
||||
|
||||
// Recreate channels for this tunnel session
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
var (
|
||||
@@ -260,6 +261,11 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||
}
|
||||
|
||||
// Create the holepunch manager
|
||||
if holePunchManager == nil {
|
||||
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
|
||||
}
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
@@ -274,12 +280,20 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new stopHolepunch channel for the new set of goroutines
|
||||
stopHolepunch = make(chan struct{})
|
||||
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
||||
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
||||
for i, node := range holePunchData.ExitNodes {
|
||||
exitNodes[i] = holepunch.ExitNode{
|
||||
Endpoint: node.Endpoint,
|
||||
PublicKey: node.PublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Start a single hole punch goroutine for all exit nodes
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
||||
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
|
||||
// Start hole punching using the manager
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||
@@ -304,20 +318,16 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
return
|
||||
}
|
||||
|
||||
// Stop any existing hole punch goroutines by closing the current channel
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
// Channel already closed
|
||||
default:
|
||||
close(stopHolepunch)
|
||||
// Stop any existing hole punch operations
|
||||
if holePunchManager != nil {
|
||||
holePunchManager.Stop()
|
||||
}
|
||||
|
||||
// Create a new stopHolepunch channel for the new set of goroutines
|
||||
stopHolepunch = make(chan struct{})
|
||||
|
||||
// Start hole punching for each exit node
|
||||
// Start hole punching for the exit node
|
||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
|
||||
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
@@ -407,6 +417,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
for {
|
||||
conn, err := uapiListener.Accept()
|
||||
if err != nil {
|
||||
|
||||
return
|
||||
}
|
||||
go dev.IpcHandle(conn)
|
||||
@@ -696,7 +707,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
return
|
||||
}
|
||||
|
||||
primaryRelay, err := resolveDomain(relayData.Endpoint)
|
||||
primaryRelay, err := ResolveDomain(relayData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||
}
|
||||
@@ -752,7 +763,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
})
|
||||
|
||||
olm.OnTokenUpdate(func(token string) {
|
||||
olmToken = token
|
||||
if holePunchManager != nil {
|
||||
holePunchManager.SetToken(token)
|
||||
}
|
||||
})
|
||||
|
||||
// Connect to the WebSocket server
|
||||
@@ -780,7 +793,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
apiServer.SetTunnelIP("")
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
@@ -799,13 +811,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
}
|
||||
|
||||
func Stop() {
|
||||
if stopHolepunch != nil {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
// Channel already closed, do nothing
|
||||
default:
|
||||
close(stopHolepunch)
|
||||
}
|
||||
// Stop hole punch manager
|
||||
if holePunchManager != nil {
|
||||
holePunchManager.Stop()
|
||||
}
|
||||
|
||||
if stopPing != nil {
|
||||
|
||||
Reference in New Issue
Block a user