diff --git a/main.go b/main.go index bb705eb..bdfc0a2 100644 --- a/main.go +++ b/main.go @@ -57,6 +57,20 @@ var ( wgClient *wgctrl.Client ) +// Add this new type at the top with other type definitions +type ClientEndpoint struct { + OlmID string `json:"olmId"` + NewtID string `json:"newtId"` + IP string `json:"ip"` + Port int `json:"port"` + Timestamp int64 `json:"timestamp"` +} + +type HolePunchMessage struct { + OlmID string `json:"olmId"` + NewtID string `json:"newtId"` +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -74,6 +88,82 @@ func parseLogLevel(level string) logger.LogLevel { } } +// Update the startUDPServer function +func startUDPServer(addr string, server string) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + logger.Fatal("Failed to resolve UDP address: %v", err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + logger.Fatal("Failed to start UDP server: %v", err) + } + defer conn.Close() + + logger.Info("UDP server listening on %s", addr) + + buffer := make([]byte, 1024) + for { + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + logger.Error("Error reading UDP packet: %v", err) + continue + } + + var msg HolePunchMessage + if err := json.Unmarshal(buffer[:n], &msg); err != nil { + logger.Error("Error unmarshaling message: %v", err) + continue + } + + // Create endpoint info + endpoint := ClientEndpoint{ + OlmID: msg.OlmID, + NewtID: msg.NewtID, + IP: remoteAddr.IP.String(), + Port: remoteAddr.Port, + Timestamp: time.Now().Unix(), + } + + // Send the endpoint info to the Olm server + go notifyServer(endpoint, server) + + logger.Info("Received hole punch from %s:%d for Olm ID: %s", + remoteAddr.IP, + remoteAddr.Port, + msg.OlmID) + } +} + +// Add this new function +func notifyServer(endpoint ClientEndpoint, server string) { + jsonData, err := json.Marshal(endpoint) + if err != nil { + logger.Error("Failed to marshal endpoint data: %v", err) + return + } + + resp, err := http.Post(server, + "application/json", + bytes.NewBuffer(jsonData)) + if err != nil { + logger.Error("Failed to notify Olm server: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + logger.Error("Olm server returned non-OK status: %d, body: %s", + resp.StatusCode, + string(body)) + return + } + + logger.Info("Successfully notified Olm server about endpoint for ID: %s", endpoint.OlmID) +} + func main() { var ( err error @@ -85,6 +175,7 @@ func main() { reachableAt string logLevel string mtu string + reportHolePunchTo string ) interfaceName = os.Getenv("INTERFACE") @@ -96,6 +187,7 @@ func main() { reachableAt = os.Getenv("REACHABLE_AT") logLevel = os.Getenv("LOG_LEVEL") mtu = os.Getenv("MTU") + reportHolePunchTo = os.Getenv("REPORT_HOLE_PUNCH_TO") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -112,6 +204,9 @@ func main() { if reportBandwidthTo == "" { flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on") } + if reportHolePunchTo == "" { + flag.StringVar(&reportHolePunchTo, "reportHolePunchTo", "", "Address to listen on") + } if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") } @@ -220,6 +315,9 @@ func main() { go periodicBandwidthCheck(reportBandwidthTo) } + // run the udp server + go startUDPServer(":21820", reportHolePunchTo) + http.HandleFunc("/peer", handlePeer) logger.Info("Starting server on %s", listenAddr) logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))