mirror of
https://github.com/fosrl/olm.git
synced 2026-02-20 03:46:42 +00:00
Custom bind?
This commit is contained in:
209
olm/common.go
209
olm/common.go
@@ -14,13 +14,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -82,11 +82,6 @@ const (
|
||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||
)
|
||||
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
// PeerAction represents a request to add, update, or remove a peer
|
||||
type PeerAction struct {
|
||||
Action string `json:"action"` // "add", "update", or "remove"
|
||||
@@ -124,11 +119,6 @@ type RelayPeerData struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
// Helper function to format endpoints correctly
|
||||
func formatEndpoint(endpoint string) string {
|
||||
if endpoint == "" {
|
||||
@@ -156,13 +146,6 @@ func formatEndpoint(endpoint string) string {
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
func fixKey(key string) string {
|
||||
// Remove any whitespace
|
||||
key = strings.TrimSpace(key)
|
||||
@@ -523,6 +506,196 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, s
|
||||
}
|
||||
}
|
||||
|
||||
// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind
|
||||
func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) {
|
||||
if len(exitNodes) == 0 {
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
defer logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := resolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch for all exit nodes")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind
|
||||
func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) {
|
||||
// Check if hole punching is already running
|
||||
if holePunchRunning {
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the flag to indicate hole punching is running
|
||||
holePunchRunning = true
|
||||
defer func() {
|
||||
holePunchRunning = false
|
||||
logger.Info("UDP hole punch goroutine ended")
|
||||
}()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||
defer logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
|
||||
host, err := resolveDomain(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send initial UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil {
|
||||
logger.Error("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind
|
||||
func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error {
|
||||
if serverPubKey == "" || olmToken == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
OlmID string `json:"olmId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
OlmID: olmID,
|
||||
Token: olmToken,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
|
||||
55
olm/olm.go
55
olm/olm.go
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/updates"
|
||||
"github.com/fosrl/olm/api"
|
||||
"github.com/fosrl/olm/bind"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
@@ -67,6 +68,7 @@ var (
|
||||
olmClient *websocket.Client
|
||||
tunnelCancel context.CancelFunc
|
||||
tunnelRunning bool
|
||||
sharedBind *bind.SharedBind
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, config Config) {
|
||||
@@ -226,10 +228,36 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
return
|
||||
}
|
||||
|
||||
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
logger.Error("Error finding available port: %v", err)
|
||||
return
|
||||
// Create shared UDP socket for both holepunch and WireGuard
|
||||
if sharedBind == nil {
|
||||
sourcePort, err := FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
logger.Error("Error finding available port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(sourcePort),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create shared UDP socket: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sharedBind, err = bind.New(udpConn)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create shared bind: %v", err)
|
||||
udpConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Add a reference for the hole punch senders (creator already has one reference for WireGuard)
|
||||
sharedBind.AddRef()
|
||||
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||
}
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||
@@ -251,7 +279,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
|
||||
// Start a single hole punch goroutine for all exit nodes
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes))
|
||||
go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort)
|
||||
go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||
@@ -289,7 +317,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
|
||||
// Start hole punching for each exit node
|
||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||
go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey)
|
||||
go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
@@ -305,7 +333,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
close(stopHolepunch)
|
||||
// close(stopHolepunch)
|
||||
|
||||
// wait 10 milliseconds to ensure the previous connection is closed
|
||||
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
||||
@@ -367,7 +395,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
||||
return
|
||||
}
|
||||
|
||||
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
|
||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
@@ -804,7 +832,7 @@ func Stop() {
|
||||
uapiListener = nil
|
||||
}
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||
dev = nil
|
||||
}
|
||||
// Close TUN device
|
||||
@@ -813,6 +841,15 @@ func Stop() {
|
||||
tdev = nil
|
||||
}
|
||||
|
||||
// Release the hole punch reference to the shared bind
|
||||
if sharedBind != nil {
|
||||
// Release hole punch reference (WireGuard already released its reference via dev.Close())
|
||||
logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount())
|
||||
sharedBind.Release()
|
||||
sharedBind = nil
|
||||
logger.Info("Released shared UDP bind")
|
||||
}
|
||||
|
||||
logger.Info("Olm service stopped")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user