Custom bind?

This commit is contained in:
Owen
2025-11-07 21:39:28 -08:00
parent aedebb5579
commit 6d8e298ebc
6 changed files with 1039 additions and 27 deletions

View File

@@ -14,13 +14,13 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -82,11 +82,6 @@ const (
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
type fixedPortBind struct {
port uint16
conn.Bind
}
// PeerAction represents a request to add, update, or remove a peer
type PeerAction struct {
Action string `json:"action"` // "add", "update", or "remove"
@@ -124,11 +119,6 @@ type RelayPeerData struct {
PublicKey string `json:"publicKey"`
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
// Helper function to format endpoints correctly
func formatEndpoint(endpoint string) string {
if endpoint == "" {
@@ -156,13 +146,6 @@ func formatEndpoint(endpoint string) string {
return endpoint
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
}
}
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
@@ -523,6 +506,196 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, s
}
}
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
if len(exitNodes) == 0 {
logger.Warn("No exit nodes provided for hole punching")
return
}
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := resolveDomain(exitNode.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch for all exit nodes")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
// Check if hole punching is already running
if holePunchRunning {
logger.Debug("UDP hole punch already running, skipping new request")
return
}
// Set the flag to indicate hole punching is running
holePunchRunning = true
defer func() {
holePunchRunning = false
logger.Info("UDP hole punch goroutine ended")
}()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
host, err := resolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send initial UDP hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
}
}
}
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
if serverPubKey == "" || olmToken == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
OlmID string `json:"olmId"`
Token string `json:"token"`
}{
OlmID: olmID,
Token: olmToken,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)

View File

@@ -12,6 +12,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/updates"
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/bind"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
@@ -67,6 +68,7 @@ var (
olmClient *websocket.Client
tunnelCancel context.CancelFunc
tunnelRunning bool
sharedBind *bind.SharedBind
)
func Run(ctx context.Context, config Config) {
@@ -226,10 +228,36 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
return
}
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
logger.Error("Error finding available port: %v", err)
return
// Create shared UDP socket for both holepunch and WireGuard
if sharedBind == nil {
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
logger.Error("Error finding available port: %v", err)
return
}
localAddr := &net.UDPAddr{
Port: int(sourcePort),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
logger.Error("Failed to create shared UDP socket: %v", err)
return
}
sharedBind, err = bind.New(udpConn)
if err != nil {
logger.Error("Failed to create shared bind: %v", err)
udpConn.Close()
return
}
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
@@ -251,7 +279,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
// Start a single hole punch goroutine for all exit nodes
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
})
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
@@ -289,7 +317,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
// Start hole punching for each exit node
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
})
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
@@ -305,7 +333,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
stopRegister = nil
}
close(stopHolepunch)
// close(stopHolepunch)
// wait 10 milliseconds to ensure the previous connection is closed
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
@@ -367,7 +395,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
@@ -804,7 +832,7 @@ func Stop() {
uapiListener = nil
}
if dev != nil {
dev.Close()
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
}
// Close TUN device
@@ -813,6 +841,15 @@ func Stop() {
tdev = nil
}
// Release the hole punch reference to the shared bind
if sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via dev.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
sharedBind.Release()
sharedBind = nil
logger.Info("Released shared UDP bind")
}
logger.Info("Olm service stopped")
}