mirror of
https://github.com/fosrl/olm.git
synced 2026-03-01 16:26:43 +00:00
Split out hp
This commit is contained in:
351
holepunch/holepunch.go
Normal file
351
holepunch/holepunch.go
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
package holepunch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/bind"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"golang.org/x/exp/rand"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DomainResolver is a function type for resolving domains to IP addresses
|
||||||
|
type DomainResolver func(string) (string, error)
|
||||||
|
|
||||||
|
// ExitNode represents a WireGuard exit node for hole punching
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles UDP hole punching operations
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
stopChan chan struct{}
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
olmID string
|
||||||
|
token string
|
||||||
|
domainResolver DomainResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new hole punch manager
|
||||||
|
func NewManager(sharedBind *bind.SharedBind, olmID string, domainResolver DomainResolver) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
sharedBind: sharedBind,
|
||||||
|
olmID: olmID,
|
||||||
|
domainResolver: domainResolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken updates the authentication token used for hole punching
|
||||||
|
func (m *Manager) SetToken(token string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.token = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether hole punching is currently active
|
||||||
|
func (m *Manager) IsRunning() bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops any ongoing hole punch operations
|
||||||
|
func (m *Manager) Stop() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if !m.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.stopChan != nil {
|
||||||
|
close(m.stopChan)
|
||||||
|
m.stopChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = false
|
||||||
|
logger.Info("Hole punch manager stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exitNodes) == 0 {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Warn("No exit nodes provided for hole punching")
|
||||||
|
return fmt.Errorf("no exit nodes provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||||
|
|
||||||
|
go m.runMultipleExitNodes(exitNodes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||||
|
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
|
||||||
|
if m.running {
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Debug("UDP hole punch already running, skipping new request")
|
||||||
|
return fmt.Errorf("hole punch already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.running = true
|
||||||
|
m.stopChan = make(chan struct{})
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||||
|
|
||||||
|
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||||
|
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Resolve all endpoints upfront
|
||||||
|
type resolvedExitNode struct {
|
||||||
|
remoteAddr *net.UDPAddr
|
||||||
|
publicKey string
|
||||||
|
endpointName string
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolvedNodes []resolvedExitNode
|
||||||
|
for _, exitNode := range exitNodes {
|
||||||
|
host, err := m.domainResolver(exitNode.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
publicKey: exitNode.PublicKey,
|
||||||
|
endpointName: exitNode.Endpoint,
|
||||||
|
})
|
||||||
|
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolvedNodes) == 0 {
|
||||||
|
logger.Error("No exit nodes could be resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send initial hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send hole punch to all exit nodes
|
||||||
|
for _, node := range resolvedNodes {
|
||||||
|
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSingleEndpoint performs hole punching to a single endpoint
|
||||||
|
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||||
|
defer func() {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.running = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||||
|
}()
|
||||||
|
|
||||||
|
host, err := m.domainResolver(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := net.JoinHostPort(host, "21820")
|
||||||
|
|
||||||
|
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute once immediately before starting the loop
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(250 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
timeout := time.NewTimer(15 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.stopChan:
|
||||||
|
logger.Debug("Hole punch stopped by signal")
|
||||||
|
return
|
||||||
|
case <-timeout.C:
|
||||||
|
logger.Debug("Hole punch timeout reached")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||||
|
logger.Debug("Failed to send hole punch: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||||
|
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
token := m.token
|
||||||
|
olmID := m.olmID
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if serverPubKey == "" || token == "" {
|
||||||
|
return fmt.Errorf("server public key or OLM token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := struct {
|
||||||
|
OlmID string `json:"olmId"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}{
|
||||||
|
OlmID: olmID,
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert payload to JSON
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload using the server's WireGuard public key
|
||||||
|
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(encryptedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||||
|
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||||
|
// Generate an ephemeral keypair for this message
|
||||||
|
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||||
|
}
|
||||||
|
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||||
|
|
||||||
|
// Parse the server's public key
|
||||||
|
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use X25519 for key exchange
|
||||||
|
var ephPrivKeyFixed [32]byte
|
||||||
|
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||||
|
|
||||||
|
// Perform X25519 key exchange
|
||||||
|
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an AEAD cipher using the shared secret
|
||||||
|
aead, err := chacha20poly1305.New(sharedSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, aead.NonceSize())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the payload
|
||||||
|
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||||
|
|
||||||
|
// Prepare the final encrypted message
|
||||||
|
encryptedMsg := struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}{
|
||||||
|
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||||
|
Nonce: nonce,
|
||||||
|
Ciphertext: ciphertext,
|
||||||
|
}
|
||||||
|
|
||||||
|
return encryptedMsg, nil
|
||||||
|
}
|
||||||
BIN
olm-binary
BIN
olm-binary
Binary file not shown.
467
olm/common.go
467
olm/common.go
@@ -3,7 +3,6 @@ package olm
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -14,12 +13,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/olm/bind"
|
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
"golang.org/x/crypto/curve25519"
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -192,7 +188,7 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveDomain(domain string) (string, error) {
|
func ResolveDomain(domain string) (string, error) {
|
||||||
// First handle any protocol prefix
|
// First handle any protocol prefix
|
||||||
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
domain = strings.TrimPrefix(strings.TrimPrefix(domain, "https://"), "http://")
|
||||||
|
|
||||||
@@ -239,463 +235,6 @@ func resolveDomain(domain string) (string, error) {
|
|||||||
return ipAddr, nil
|
return ipAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
|
||||||
if serverPubKey == "" || olmToken == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := struct {
|
|
||||||
OlmID string `json:"olmId"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
}{
|
|
||||||
OlmID: olmID,
|
|
||||||
Token: olmToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert payload to JSON
|
|
||||||
payloadBytes, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt the payload using the server's WireGuard public key
|
|
||||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(encryptedPayload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.WriteToUDP(jsonData, remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
|
||||||
// Generate an ephemeral keypair for this message
|
|
||||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
|
||||||
}
|
|
||||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
|
||||||
|
|
||||||
// Parse the server's public key
|
|
||||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
|
||||||
var ephPrivKeyFixed [32]byte
|
|
||||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
|
||||||
|
|
||||||
// Perform X25519 key exchange
|
|
||||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create an AEAD cipher using the shared secret
|
|
||||||
aead, err := chacha20poly1305.New(sharedSecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a random nonce
|
|
||||||
nonce := make([]byte, aead.NonceSize())
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt the payload
|
|
||||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
|
||||||
|
|
||||||
// Prepare the final encrypted message
|
|
||||||
encryptedMsg := struct {
|
|
||||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
|
||||||
Nonce []byte `json:"nonce"`
|
|
||||||
Ciphertext []byte `json:"ciphertext"`
|
|
||||||
}{
|
|
||||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
|
||||||
Nonce: nonce,
|
|
||||||
Ciphertext: ciphertext,
|
|
||||||
}
|
|
||||||
|
|
||||||
return encryptedMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) {
|
|
||||||
if len(exitNodes) == 0 {
|
|
||||||
logger.Warn("No exit nodes provided for hole punching")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if hole punching is already running
|
|
||||||
if holePunchRunning {
|
|
||||||
logger.Debug("UDP hole punch already running, skipping new request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the flag to indicate hole punching is running
|
|
||||||
holePunchRunning = true
|
|
||||||
defer func() {
|
|
||||||
holePunchRunning = false
|
|
||||||
logger.Info("UDP hole punch goroutine ended")
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes))
|
|
||||||
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
|
||||||
|
|
||||||
// Create the UDP connection once and reuse it for all exit nodes
|
|
||||||
localAddr := &net.UDPAddr{
|
|
||||||
Port: int(sourcePort),
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", localAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to bind UDP socket: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Resolve all endpoints upfront
|
|
||||||
type resolvedExitNode struct {
|
|
||||||
remoteAddr *net.UDPAddr
|
|
||||||
publicKey string
|
|
||||||
endpointName string
|
|
||||||
}
|
|
||||||
|
|
||||||
var resolvedNodes []resolvedExitNode
|
|
||||||
for _, exitNode := range exitNodes {
|
|
||||||
host, err := resolveDomain(exitNode.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr := net.JoinHostPort(host, "21820")
|
|
||||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
publicKey: exitNode.PublicKey,
|
|
||||||
endpointName: exitNode.Endpoint,
|
|
||||||
})
|
|
||||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resolvedNodes) == 0 {
|
|
||||||
logger.Error("No exit nodes could be resolved")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send initial hole punch to all exit nodes
|
|
||||||
for _, node := range resolvedNodes {
|
|
||||||
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
||||||
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
timeout := time.NewTimer(15 * time.Second)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopHolepunch:
|
|
||||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
|
||||||
return
|
|
||||||
case <-timeout.C:
|
|
||||||
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
// Send hole punch to all exit nodes
|
|
||||||
for _, node := range resolvedNodes {
|
|
||||||
if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
||||||
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) {
|
|
||||||
|
|
||||||
// Check if hole punching is already running
|
|
||||||
if holePunchRunning {
|
|
||||||
logger.Debug("UDP hole punch already running, skipping new request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the flag to indicate hole punching is running
|
|
||||||
holePunchRunning = true
|
|
||||||
defer func() {
|
|
||||||
holePunchRunning = false
|
|
||||||
logger.Info("UDP hole punch goroutine ended")
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("Starting UDP hole punch to %s", endpoint)
|
|
||||||
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
|
||||||
|
|
||||||
host, err := resolveDomain(endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve endpoint: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr := net.JoinHostPort(host, "21820")
|
|
||||||
|
|
||||||
// Create the UDP connection once and reuse it
|
|
||||||
localAddr := &net.UDPAddr{
|
|
||||||
Port: int(sourcePort),
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve UDP address: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", localAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to bind UDP socket: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Execute once immediately before starting the loop
|
|
||||||
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
|
||||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
timeout := time.NewTimer(15 * time.Second)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopHolepunch:
|
|
||||||
logger.Info("Stopping UDP holepunch")
|
|
||||||
return
|
|
||||||
case <-timeout.C:
|
|
||||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil {
|
|
||||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
|
|
||||||
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
|
|
||||||
if len(exitNodes) == 0 {
|
|
||||||
logger.Warn("No exit nodes provided for hole punching")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if hole punching is already running
|
|
||||||
if holePunchRunning {
|
|
||||||
logger.Debug("UDP hole punch already running, skipping new request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the flag to indicate hole punching is running
|
|
||||||
holePunchRunning = true
|
|
||||||
defer func() {
|
|
||||||
holePunchRunning = false
|
|
||||||
logger.Info("UDP hole punch goroutine ended")
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
|
||||||
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
|
||||||
|
|
||||||
// Resolve all endpoints upfront
|
|
||||||
type resolvedExitNode struct {
|
|
||||||
remoteAddr *net.UDPAddr
|
|
||||||
publicKey string
|
|
||||||
endpointName string
|
|
||||||
}
|
|
||||||
|
|
||||||
var resolvedNodes []resolvedExitNode
|
|
||||||
for _, exitNode := range exitNodes {
|
|
||||||
host, err := resolveDomain(exitNode.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr := net.JoinHostPort(host, "21820")
|
|
||||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
publicKey: exitNode.PublicKey,
|
|
||||||
endpointName: exitNode.Endpoint,
|
|
||||||
})
|
|
||||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resolvedNodes) == 0 {
|
|
||||||
logger.Error("No exit nodes could be resolved")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send initial hole punch to all exit nodes
|
|
||||||
for _, node := range resolvedNodes {
|
|
||||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
||||||
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
timeout := time.NewTimer(15 * time.Second)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopHolepunch:
|
|
||||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
|
||||||
return
|
|
||||||
case <-timeout.C:
|
|
||||||
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
// Send hole punch to all exit nodes
|
|
||||||
for _, node := range resolvedNodes {
|
|
||||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
|
||||||
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
|
|
||||||
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
|
|
||||||
// Check if hole punching is already running
|
|
||||||
if holePunchRunning {
|
|
||||||
logger.Debug("UDP hole punch already running, skipping new request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the flag to indicate hole punching is running
|
|
||||||
holePunchRunning = true
|
|
||||||
defer func() {
|
|
||||||
holePunchRunning = false
|
|
||||||
logger.Info("UDP hole punch goroutine ended")
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
|
||||||
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
|
||||||
|
|
||||||
host, err := resolveDomain(endpoint)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAddr := net.JoinHostPort(host, "21820")
|
|
||||||
|
|
||||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute once immediately before starting the loop
|
|
||||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
|
||||||
logger.Error("Failed to send initial UDP hole punch: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ticker := time.NewTicker(250 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
timeout := time.NewTimer(15 * time.Second)
|
|
||||||
defer timeout.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stopHolepunch:
|
|
||||||
logger.Info("Stopping UDP holepunch")
|
|
||||||
return
|
|
||||||
case <-timeout.C:
|
|
||||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
|
||||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
|
|
||||||
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
|
||||||
if serverPubKey == "" || olmToken == "" {
|
|
||||||
return fmt.Errorf("server public key or OLM token is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := struct {
|
|
||||||
OlmID string `json:"olmId"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
}{
|
|
||||||
OlmID: olmID,
|
|
||||||
Token: olmToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert payload to JSON
|
|
||||||
payloadBytes, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt the payload using the server's WireGuard public key
|
|
||||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(encryptedPayload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||||
if maxPort < minPort {
|
if maxPort < minPort {
|
||||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||||
@@ -772,7 +311,7 @@ func keepSendingPing(olm *websocket.Client) {
|
|||||||
|
|
||||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||||
siteHost, err := resolveDomain(siteConfig.Endpoint)
|
siteHost, err := ResolveDomain(siteConfig.Endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
}
|
}
|
||||||
@@ -829,7 +368,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
|
|||||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||||
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
|
||||||
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
|
primaryRelay, err := ResolveDomain(endpoint) // Using global endpoint variable
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
86
olm/olm.go
86
olm/olm.go
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
"github.com/fosrl/olm/bind"
|
"github.com/fosrl/olm/bind"
|
||||||
|
"github.com/fosrl/olm/holepunch"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -57,18 +58,19 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
connected bool
|
connected bool
|
||||||
dev *device.Device
|
dev *device.Device
|
||||||
wgData WgData
|
wgData WgData
|
||||||
holePunchData HolePunchData
|
holePunchData HolePunchData
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
tdev tun.Device
|
tdev tun.Device
|
||||||
apiServer *api.API
|
apiServer *api.API
|
||||||
olmClient *websocket.Client
|
olmClient *websocket.Client
|
||||||
tunnelCancel context.CancelFunc
|
tunnelCancel context.CancelFunc
|
||||||
tunnelRunning bool
|
tunnelRunning bool
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
|
holePunchManager *holepunch.Manager
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(ctx context.Context, config Config) {
|
func Run(ctx context.Context, config Config) {
|
||||||
@@ -197,7 +199,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Recreate channels for this tunnel session
|
// Recreate channels for this tunnel session
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
stopPing = make(chan struct{})
|
stopPing = make(chan struct{})
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -260,6 +261,11 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the holepunch manager
|
||||||
|
if holePunchManager == nil {
|
||||||
|
holePunchManager = holepunch.NewManager(sharedBind, id, ResolveDomain)
|
||||||
|
}
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||||
logger.Debug("Received message: %v", msg.Data)
|
logger.Debug("Received message: %v", msg.Data)
|
||||||
|
|
||||||
@@ -274,12 +280,20 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new stopHolepunch channel for the new set of goroutines
|
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
||||||
stopHolepunch = make(chan struct{})
|
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
||||||
|
for i, node := range holePunchData.ExitNodes {
|
||||||
|
exitNodes[i] = holepunch.ExitNode{
|
||||||
|
Endpoint: node.Endpoint,
|
||||||
|
PublicKey: node.PublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start a single hole punch goroutine for all exit nodes
|
// Start hole punching using the manager
|
||||||
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||||
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
|
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||||
@@ -304,20 +318,16 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop any existing hole punch goroutines by closing the current channel
|
// Stop any existing hole punch operations
|
||||||
select {
|
if holePunchManager != nil {
|
||||||
case <-stopHolepunch:
|
holePunchManager.Stop()
|
||||||
// Channel already closed
|
|
||||||
default:
|
|
||||||
close(stopHolepunch)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new stopHolepunch channel for the new set of goroutines
|
// Start hole punching for the exit node
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
|
|
||||||
// Start hole punching for each exit node
|
|
||||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||||
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
|
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
||||||
|
logger.Warn("Failed to start hole punch: %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||||
@@ -407,6 +417,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
for {
|
for {
|
||||||
conn, err := uapiListener.Accept()
|
conn, err := uapiListener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go dev.IpcHandle(conn)
|
go dev.IpcHandle(conn)
|
||||||
@@ -696,7 +707,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryRelay, err := resolveDomain(relayData.Endpoint)
|
primaryRelay, err := ResolveDomain(relayData.Endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
}
|
}
|
||||||
@@ -752,7 +763,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
})
|
})
|
||||||
|
|
||||||
olm.OnTokenUpdate(func(token string) {
|
olm.OnTokenUpdate(func(token string) {
|
||||||
olmToken = token
|
if holePunchManager != nil {
|
||||||
|
holePunchManager.SetToken(token)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
@@ -780,7 +793,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
apiServer.SetTunnelIP("")
|
apiServer.SetTunnelIP("")
|
||||||
apiServer.SetOrgID(config.OrgID)
|
apiServer.SetOrgID(config.OrgID)
|
||||||
|
|
||||||
stopHolepunch = make(chan struct{})
|
|
||||||
// Trigger re-registration with new orgId
|
// Trigger re-registration with new orgId
|
||||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
@@ -799,13 +811,9 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Stop() {
|
func Stop() {
|
||||||
if stopHolepunch != nil {
|
// Stop hole punch manager
|
||||||
select {
|
if holePunchManager != nil {
|
||||||
case <-stopHolepunch:
|
holePunchManager.Stop()
|
||||||
// Channel already closed, do nothing
|
|
||||||
default:
|
|
||||||
close(stopHolepunch)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if stopPing != nil {
|
if stopPing != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user