Add holepunch udp server

This commit is contained in:
Owen
2025-02-21 22:28:16 -05:00
parent aa4f4ebfab
commit 7b3f7d2b12

98
main.go
View File

@@ -57,6 +57,20 @@ var (
wgClient *wgctrl.Client
)
// Add this new type at the top with other type definitions
type ClientEndpoint struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
IP string `json:"ip"`
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
}
type HolePunchMessage struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
@@ -74,6 +88,82 @@ func parseLogLevel(level string) logger.LogLevel {
}
}
// Update the startUDPServer function
func startUDPServer(addr string, server string) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
logger.Fatal("Failed to resolve UDP address: %v", err)
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
logger.Fatal("Failed to start UDP server: %v", err)
}
defer conn.Close()
logger.Info("UDP server listening on %s", addr)
buffer := make([]byte, 1024)
for {
n, remoteAddr, err := conn.ReadFromUDP(buffer)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
continue
}
var msg HolePunchMessage
if err := json.Unmarshal(buffer[:n], &msg); err != nil {
logger.Error("Error unmarshaling message: %v", err)
continue
}
// Create endpoint info
endpoint := ClientEndpoint{
OlmID: msg.OlmID,
NewtID: msg.NewtID,
IP: remoteAddr.IP.String(),
Port: remoteAddr.Port,
Timestamp: time.Now().Unix(),
}
// Send the endpoint info to the Olm server
go notifyServer(endpoint, server)
logger.Info("Received hole punch from %s:%d for Olm ID: %s",
remoteAddr.IP,
remoteAddr.Port,
msg.OlmID)
}
}
// Add this new function
func notifyServer(endpoint ClientEndpoint, server string) {
jsonData, err := json.Marshal(endpoint)
if err != nil {
logger.Error("Failed to marshal endpoint data: %v", err)
return
}
resp, err := http.Post(server,
"application/json",
bytes.NewBuffer(jsonData))
if err != nil {
logger.Error("Failed to notify Olm server: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Olm server returned non-OK status: %d, body: %s",
resp.StatusCode,
string(body))
return
}
logger.Info("Successfully notified Olm server about endpoint for ID: %s", endpoint.OlmID)
}
func main() {
var (
err error
@@ -85,6 +175,7 @@ func main() {
reachableAt string
logLevel string
mtu string
reportHolePunchTo string
)
interfaceName = os.Getenv("INTERFACE")
@@ -96,6 +187,7 @@ func main() {
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
reportHolePunchTo = os.Getenv("REPORT_HOLE_PUNCH_TO")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
@@ -112,6 +204,9 @@ func main() {
if reportBandwidthTo == "" {
flag.StringVar(&reportBandwidthTo, "reportBandwidthTo", "", "Address to listen on")
}
if reportHolePunchTo == "" {
flag.StringVar(&reportHolePunchTo, "reportHolePunchTo", "", "Address to listen on")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
@@ -220,6 +315,9 @@ func main() {
go periodicBandwidthCheck(reportBandwidthTo)
}
// run the udp server
go startUDPServer(":21820", reportHolePunchTo)
http.HandleFunc("/peer", handlePeer)
logger.Info("Starting server on %s", listenAddr)
logger.Fatal("Failed to start server: %v", http.ListenAndServe(listenAddr, nil))