From 43a43b429d746e71124a9913e717fef96a69c278 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 22 Feb 2025 11:20:30 -0500 Subject: [PATCH] Initial hp working maybe? --- main.go | 190 +++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 139 insertions(+), 51 deletions(-) diff --git a/main.go b/main.go index cb93126..1ece254 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/fosrl/newt/logger" "github.com/fosrl/olm/websocket" @@ -46,6 +47,34 @@ type TargetData struct { Targets []string `json:"targets"` } +var ( + stopHolepunch chan struct{} + stopRegister chan struct{} +) + +const ( + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" + ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" +) + +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + func fixKey(key string) string { // Remove any whitespace key = strings.TrimSpace(key) @@ -60,12 +89,6 @@ func fixKey(key string) string { return hex.EncodeToString(decoded) } -const ( - ENV_WG_TUN_FD = "WG_TUN_FD" - ENV_WG_UAPI_FD = "WG_UAPI_FD" - ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" -) - func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -201,7 +224,7 @@ func configureDarwin(interfaceName string, ip net.IP, destIp string) error { ipStr := ip.String() - cmd := exec.Command("ifconfig", interfaceName, ipStr+"/24", destIp, "up") + cmd := exec.Command("ifconfig", interfaceName, ipStr+"/24", destIp, "up") // TODO: dont hard code /24 // print the command used logger.Info("Running command: %v", cmd) @@ -213,20 +236,6 @@ func configureDarwin(interfaceName string, ip net.IP, destIp string) error { return nil } -// Helper function for ioctl calls -func ioctl(fd int, request uint, argp uintptr) error { - _, _, errno := syscall.Syscall( - syscall.SYS_IOCTL, - uintptr(fd), - uintptr(request), - argp, - ) - if errno != 0 { - return os.NewSyscallError("ioctl", errno) - } - return nil -} - func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { // Get the interface link, err := netlink.LinkByName(interfaceName) @@ -256,10 +265,10 @@ func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { } // TODO: we need to send the token with this probably to verify auth -func sendUDPHolePunch(serverAddr string, olmID string, sourcePort int) error { +func sendUDPHolePunch(serverAddr string, olmID string, sourcePort uint16) error { // Bind to specific local port localAddr := &net.UDPAddr{ - Port: sourcePort, + Port: int(sourcePort), IP: net.IPv4zero, } @@ -293,21 +302,85 @@ func sendUDPHolePunch(serverAddr string, olmID string, sourcePort int) error { return nil } -type fixedPortBind struct { - port uint16 - conn.Bind -} - -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), +func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { + var host = endpoint + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := sendUDPHolePunch(host+":21820", olmID, sourcePort); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +} + +func sendRegistration(olm *websocket.Client, publicKey string) error { + err := olm.SendMessage("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + logger.Info("Sent registration message") + return nil +} + +func keepSendingRegistration(olm *websocket.Client, publicKey string) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-stopRegister: + logger.Info("Stopping registration messages") + return + case <-ticker.C: + if err := sendRegistration(olm, publicKey); err != nil { + logger.Error("Failed to send periodic registration: %v", err) + } + } + } +} + +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + for port := minPort; port <= maxPort; port++ { + // Create the UDP address to test + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + + // Attempt to create a UDP listener + conn, err := net.ListenUDP("udp", addr) + if err != nil { + continue // Port is in use or there was an error, try next port + } + + // Close the connection immediately + _ = conn.SetDeadline(time.Now()) + conn.Close() + + return port, nil + } + + return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) } func main() { @@ -326,7 +399,8 @@ func main() { reachableAt string ) - const sourcePort = 21821 + stopHolepunch = make(chan struct{}) + stopRegister = make(chan struct{}) // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values endpoint = os.Getenv("PANGOLIN_ENDPOINT") @@ -402,9 +476,14 @@ func main() { logger.Fatal("Failed to create olm: %v", err) } + sourcePort, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + os.Exit(1) + } + // Create TUN device and network stack var dev *device.Device - // var connected bool var wgData WgData var uapi *os.File @@ -417,6 +496,8 @@ func main() { olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received message: %v", msg.Data) + close(stopRegister) + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) @@ -531,7 +612,7 @@ func main() { public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) +persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) err = dev.IpcSet(config) if err != nil { @@ -550,6 +631,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Error("Failed to configure interface: %v", err) } + close(stopHolepunch) logger.Info("WireGuard device created.") }) @@ -557,22 +639,14 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub publicKey := privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) - err := olm.SendMessage("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } + go keepSendingRegistration(olm, publicKey.String()) logger.Info("Sent registration message") return nil }) - // Send holepunch from specific port - if err := sendUDPHolePunch(endpoint+":21820", id, sourcePort); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } + // start sending UDP hole punch + go keepSendingUDPHolePunch(endpoint, id, sourcePort) // Connect to the WebSocket server if err := olm.Connect(); err != nil { @@ -585,6 +659,20 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh + select { + case <-stopHolepunch: + // Channel already closed, do nothing + default: + close(stopHolepunch) + } + + select { + case <-stopRegister: + // Channel already closed + default: + close(stopRegister) + } + uapi.Close() dev.Close() }