diff --git a/main.go b/main.go index 4270251..e2d2c16 100644 --- a/main.go +++ b/main.go @@ -82,6 +82,12 @@ type ProxyMappingUpdate struct { NewDestination relay.PeerDestination `json:"newDestination"` } +type UpdateDestinationsRequest struct { + SourceIP string `json:"sourceIp"` + SourcePort int `json:"sourcePort"` + Destinations []relay.PeerDestination `json:"destinations"` +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -252,7 +258,7 @@ func main() { go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") // Start the UDP proxy server - proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key) + proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) err = proxyServer.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) @@ -262,6 +268,7 @@ func main() { // Set up HTTP server http.HandleFunc("/peer", handlePeer) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) + http.HandleFunc("/update-destinations", handleUpdateDestinations) logger.Info("Starting HTTP server on %s", listenAddr) // Run HTTP server in a goroutine @@ -727,29 +734,34 @@ func removePeerInternal(publicKey string) error { func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { + logger.Error("Invalid method: %s", r.Method) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } var update ProxyMappingUpdate if err := json.NewDecoder(r.Body).Decode(&update); err != nil { + logger.Error("Failed to decode request body: %v", err) http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest) return } // Validate the update request if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" { + logger.Error("Both old and new destination IP addresses are required") http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest) return } if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 { + logger.Error("Both old and new destination ports must be positive integers") http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest) return } // Update the proxy mappings in the relay server if proxyServer == nil { + logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } @@ -770,6 +782,75 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { }) } +func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + logger.Error("Invalid method: %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var request UpdateDestinationsRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + logger.Error("Failed to decode request body: %v", err) + http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest) + return + } + + // Validate the request + if request.SourceIP == "" { + logger.Error("Source IP address is required") + http.Error(w, "Source IP address is required", http.StatusBadRequest) + return + } + + if request.SourcePort <= 0 { + logger.Error("Source port must be a positive integer") + http.Error(w, "Source port must be a positive integer", http.StatusBadRequest) + return + } + + if len(request.Destinations) == 0 { + logger.Error("At least one destination is required") + http.Error(w, "At least one destination is required", http.StatusBadRequest) + return + } + + // Validate each destination + for i, dest := range request.Destinations { + if dest.DestinationIP == "" { + logger.Error("Destination IP is required for destination %d", i) + http.Error(w, fmt.Sprintf("Destination IP is required for destination %d", i), http.StatusBadRequest) + return + } + if dest.DestinationPort <= 0 { + logger.Error("Destination port must be a positive integer for destination %d", i) + http.Error(w, fmt.Sprintf("Destination port must be a positive integer for destination %d", i), http.StatusBadRequest) + return + } + } + + // Update the proxy mappings in the relay server + if proxyServer == nil { + logger.Error("Proxy server is not available") + http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + return + } + + proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) + + logger.Info("Updated proxy mapping for %s:%d with %d destinations", + request.SourceIP, request.SourcePort, len(request.Destinations)) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "Destinations updated successfully", + "sourceIP": request.SourceIP, + "sourcePort": request.SourcePort, + "destinationCount": len(request.Destinations), + "destinations": request.Destinations, + }) +} + func periodicBandwidthCheck(endpoint string) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() diff --git a/relay/relay.go b/relay/relay.go index 939f4d2..2f87d78 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "strconv" "sync" "time" @@ -31,12 +30,13 @@ type HolePunchMessage struct { } type ClientEndpoint struct { - OlmID string `json:"olmId"` - NewtID string `json:"newtId"` - Token string `json:"token"` - IP string `json:"ip"` - Port int `json:"port"` - Timestamp int64 `json:"timestamp"` + OlmID string `json:"olmId"` + NewtID string `json:"newtId"` + Token string `json:"token"` + IP string `json:"ip"` + Port int `json:"port"` + Timestamp int64 `json:"timestamp"` + ReachableAt string `json:"reachableAt"` } // Updated to support multiple destination peers @@ -104,15 +104,18 @@ type UDPProxyServer struct { // Session tracking for WireGuard peers // Key format: "senderIndex:receiverIndex" wgSessions sync.Map + // ReachableAt is the URL where this server can be reached + ReachableAt string } // NewUDPProxyServer initializes the server with a buffered packet channel. -func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key) *UDPProxyServer { +func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { return &UDPProxyServer{ - addr: addr, - serverURL: serverURL, - privateKey: privateKey, - packetChan: make(chan Packet, 1000), + addr: addr, + serverURL: serverURL, + privateKey: privateKey, + packetChan: make(chan Packet, 1000), + ReachableAt: reachableAt, } } @@ -215,12 +218,13 @@ func (s *UDPProxyServer) packetWorker() { } endpoint := ClientEndpoint{ - NewtID: msg.NewtID, - OlmID: msg.OlmID, - Token: msg.Token, - IP: packet.remoteAddr.IP.String(), - Port: packet.remoteAddr.Port, - Timestamp: time.Now().Unix(), + NewtID: msg.NewtID, + OlmID: msg.OlmID, + Token: msg.Token, + IP: packet.remoteAddr.IP.String(), + Port: packet.remoteAddr.Port, + Timestamp: time.Now().Unix(), + ReachableAt: s.ReachableAt, } logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port) s.notifyServer(endpoint) @@ -644,7 +648,7 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { // Updated to support multiple destinations func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) { - key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort)) + key := fmt.Sprintf("%s:%d", sourceIP, sourcePort) mapping := ProxyMapping{ Destinations: destinations, LastUsed: time.Now(),