diff --git a/.dockerignore b/.dockerignore index d16e2de..d5ba4e4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,4 +6,5 @@ README.md Makefile public/ LICENSE -CONTRIBUTING.md \ No newline at end of file +CONTRIBUTING.md +.git diff --git a/main.go b/main.go index c18441d..9f01344 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,7 @@ var ( lastReadings = make(map[string]PeerReading) mu sync.Mutex notifyURL string + proxyServer *relay.UDPProxyServer ) type WgConfig struct { @@ -75,6 +76,11 @@ type HolePunchMessage struct { NewtID string `json:"newtId"` } +type ProxyMappingUpdate struct { + OldDestination relay.PeerDestination `json:"oldDestination"` + NewDestination relay.PeerDestination `json:"newDestination"` +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -245,15 +251,16 @@ func main() { go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") // Start the UDP proxy server - server := relay.NewUDPProxyServer(":21820", remoteConfigURL, key) - err = server.Start() + proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key) + err = proxyServer.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) } - defer server.Stop() + defer proxyServer.Stop() // Set up HTTP server http.HandleFunc("/peer", handlePeer) + http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) logger.Info("Starting HTTP server on %s", listenAddr) // Run HTTP server in a goroutine @@ -598,12 +605,15 @@ func addPeer(peer Peer) error { // parse allowed IPs into array of net.IPNet var allowedIPs []net.IPNet + var wgIPs []string for _, ipStr := range peer.AllowedIPs { _, ipNet, err := net.ParseCIDR(ipStr) if err != nil { return fmt.Errorf("failed to parse allowed IP: %v", err) } allowedIPs = append(allowedIPs, *ipNet) + // Extract the IP address from the CIDR for relay cleanup + wgIPs = append(wgIPs, ipNet.IP.String()) } peerConfig := wgtypes.PeerConfig{ @@ -619,6 +629,13 @@ func addPeer(peer Peer) error { return fmt.Errorf("failed to add peer: %v", err) } + // Clear relay connections for the peer's WireGuard IPs + if proxyServer != nil { + for _, wgIP := range wgIPs { + proxyServer.OnPeerAdded(wgIP) + } + } + logger.Info("Peer %s added successfully", peer.PublicKey) return nil @@ -650,6 +667,23 @@ func removePeer(publicKey string) error { return fmt.Errorf("failed to parse public key: %v", err) } + // Get current peer info before removing to clear relay connections + var wgIPs []string + if proxyServer != nil { + device, err := wgClient.Device(interfaceName) + if err == nil { + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + // Extract WireGuard IPs from this peer's allowed IPs + for _, allowedIP := range peer.AllowedIPs { + wgIPs = append(wgIPs, allowedIP.IP.String()) + } + break + } + } + } + } + peerConfig := wgtypes.PeerConfig{ PublicKey: pubKey, Remove: true, @@ -663,11 +697,63 @@ func removePeer(publicKey string) error { return fmt.Errorf("failed to remove peer: %v", err) } + // Clear relay connections for the peer's WireGuard IPs + if proxyServer != nil { + for _, wgIP := range wgIPs { + proxyServer.OnPeerRemoved(wgIP) + } + } + logger.Info("Peer %s removed successfully", publicKey) return nil } +func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var update ProxyMappingUpdate + if err := json.NewDecoder(r.Body).Decode(&update); err != nil { + http.Error(w, fmt.Sprintf("Failed to decode request body: %v", err), http.StatusBadRequest) + return + } + + // Validate the update request + if update.OldDestination.DestinationIP == "" || update.NewDestination.DestinationIP == "" { + http.Error(w, "Both old and new destination IP addresses are required", http.StatusBadRequest) + return + } + + if update.OldDestination.DestinationPort <= 0 || update.NewDestination.DestinationPort <= 0 { + http.Error(w, "Both old and new destination ports must be positive integers", http.StatusBadRequest) + return + } + + // Update the proxy mappings in the relay server + if proxyServer == nil { + http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + return + } + + updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) + + logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d", + updatedCount, + update.OldDestination.DestinationIP, update.OldDestination.DestinationPort, + update.NewDestination.DestinationIP, update.NewDestination.DestinationPort) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "Proxy mappings updated successfully", + "updatedCount": updatedCount, + "oldDestination": update.OldDestination, + "newDestination": update.NewDestination, + }) +} + func periodicBandwidthCheck(endpoint string) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() diff --git a/relay/relay.go b/relay/relay.go index dd293fc..a71eeed 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -42,6 +42,7 @@ type ClientEndpoint struct { // Updated to support multiple destination peers type ProxyMapping struct { Destinations []PeerDestination `json:"destinations"` + LastUsed time.Time `json:"-"` // Not serialized, used for cleanup } type PeerDestination struct { @@ -148,6 +149,9 @@ func (s *UDPProxyServer) Start() error { // Start the session cleanup routine go s.cleanupIdleSessions() + // Start the proxy mapping cleanup routine + go s.cleanupIdleProxyMappings() + return nil } @@ -218,6 +222,7 @@ func (s *UDPProxyServer) packetWorker() { Port: packet.remoteAddr.Port, Timestamp: time.Now().Unix(), } + logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port) s.notifyServer(endpoint) } // Return the buffer to the pool for reuse. @@ -282,6 +287,8 @@ func (s *UDPProxyServer) fetchInitialMappings() error { } // Store mappings in our sync.Map. for key, mapping := range initialMappings.Mappings { + // Initialize LastUsed timestamp for initial mappings + mapping.LastUsed = time.Now() s.proxyMappings.Store(key, mapping) } logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) @@ -336,6 +343,9 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD } proxyMapping := mappingObj.(ProxyMapping) + // Update the last used timestamp and store it back + proxyMapping.LastUsed = time.Now() + s.proxyMappings.Store(key, proxyMapping) // Handle different WireGuard message types switch messageType { @@ -574,7 +584,26 @@ func (s *UDPProxyServer) cleanupIdleSessions() { } } +// New method to periodically remove idle proxy mappings +func (s *UDPProxyServer) cleanupIdleProxyMappings() { + ticker := time.NewTicker(10 * time.Minute) + for range ticker.C { + now := time.Now() + s.proxyMappings.Range(func(key, value interface{}) bool { + mapping := value.(ProxyMapping) + // Remove mappings that haven't been used in 30 minutes + if now.Sub(mapping.LastUsed) > 30*time.Minute { + s.proxyMappings.Delete(key) + logger.Debug("Removed idle proxy mapping: %s", key) + } + return true + }) + } +} + func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { + logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port) + jsonData, err := json.Marshal(endpoint) if err != nil { logger.Error("Failed to marshal endpoint data: %v", err) @@ -602,13 +631,15 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { return } - logger.Debug("Received proxy mapping: %v", mapping) + logger.Debug("Received proxy mapping from server: %v", mapping) - // Store the mapping + // Store the mapping with current timestamp key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) + logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port) + mapping.LastUsed = time.Now() s.proxyMappings.Store(key, mapping) - logger.Debug("Stored proxy mapping for %s with %d destinations", key, len(mapping.Destinations)) + logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed) } // Updated to support multiple destinations @@ -616,6 +647,172 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort)) mapping := ProxyMapping{ Destinations: destinations, + LastUsed: time.Now(), } s.proxyMappings.Store(key, mapping) } + +// OnPeerAdded clears connections and sessions for a specific WireGuard IP to allow re-establishment +func (s *UDPProxyServer) OnPeerAdded(wgIP string) { + logger.Info("Clearing connections for added peer with WG IP: %s", wgIP) + s.clearConnectionsForWGIP(wgIP) + s.clearSessionsForWGIP(wgIP) + // s.clearProxyMappingsForWGIP(wgIP) +} + +// OnPeerRemoved clears connections and sessions for a specific WireGuard IP +func (s *UDPProxyServer) OnPeerRemoved(wgIP string) { + logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP) + s.clearConnectionsForWGIP(wgIP) + s.clearSessionsForWGIP(wgIP) + // s.clearProxyMappingsForWGIP(wgIP) +} + +// clearConnectionsForWGIP removes all connections associated with a specific WireGuard IP +func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) { + var keysToDelete []string + + s.connections.Range(func(key, value interface{}) bool { + keyStr := key.(string) + destConn := value.(*DestinationConn) + + // Connection keys are in format "destAddr-remoteAddr" + // Check if either destination or remote address contains the WG IP + if containsIP(keyStr, wgIP) { + keysToDelete = append(keysToDelete, keyStr) + destConn.conn.Close() + logger.Debug("Closing connection for WG IP %s: %s", wgIP, keyStr) + } + return true + }) + + // Delete the connections + for _, key := range keysToDelete { + s.connections.Delete(key) + } + + logger.Info("Cleared %d connections for WG IP: %s", len(keysToDelete), wgIP) +} + +// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP +func (s *UDPProxyServer) clearSessionsForWGIP(wgIP string) { + var keysToDelete []string + + s.wgSessions.Range(func(key, value interface{}) bool { + keyStr := key.(string) + session := value.(*WireGuardSession) + + // Check if the session's destination address contains the WG IP + if session.DestAddr != nil && session.DestAddr.IP.String() == wgIP { + keysToDelete = append(keysToDelete, keyStr) + logger.Debug("Marking session for deletion for WG IP %s: %s", wgIP, keyStr) + } + return true + }) + + // Delete the sessions + for _, key := range keysToDelete { + s.wgSessions.Delete(key) + } + + logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), wgIP) +} + +// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP +// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) { +// var keysToDelete []string + +// s.proxyMappings.Range(func(key, value interface{}) bool { +// keyStr := key.(string) +// mapping := value.(ProxyMapping) + +// // Check if any destination in the mapping contains the WG IP +// for _, dest := range mapping.Destinations { +// if dest.DestinationIP == wgIP { +// keysToDelete = append(keysToDelete, keyStr) +// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort) +// break // Found one destination, no need to check others in this mapping +// } +// } +// return true +// }) + +// // Delete the proxy mappings +// for _, key := range keysToDelete { +// s.proxyMappings.Delete(key) +// logger.Debug("Deleted proxy mapping: %s", key) +// } + +// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP) +// } + +// containsIP checks if a connection key string contains the specified IP address +func containsIP(connectionKey, ip string) bool { + // Connection keys are in format "destIP:destPort-remoteIP:remotePort" + // Check if the IP appears at the beginning (destination) or after the dash (remote) + ipWithColon := ip + ":" + + // Check if connection key starts with the IP (destination address) + if len(connectionKey) >= len(ipWithColon) && connectionKey[:len(ipWithColon)] == ipWithColon { + return true + } + + // Check if connection key contains the IP after a dash (remote address) + dashIndex := -1 + for i := 0; i < len(connectionKey); i++ { + if connectionKey[i] == '-' { + dashIndex = i + break + } + } + + if dashIndex != -1 && dashIndex+1 < len(connectionKey) { + remainingPart := connectionKey[dashIndex+1:] + if len(remainingPart) >= len(ip)+1 && remainingPart[:len(ip)+1] == ipWithColon { + return true + } + } + + return false +} + +// UpdateDestinationInMappings updates all proxy mappings that contain the old destination with the new destination +// Returns the number of mappings that were updated +func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int { + updatedCount := 0 + + s.proxyMappings.Range(func(key, value interface{}) bool { + keyStr := key.(string) + mapping := value.(ProxyMapping) + updated := false + + // Check each destination in the mapping + for i, dest := range mapping.Destinations { + if dest.DestinationIP == oldDest.DestinationIP && dest.DestinationPort == oldDest.DestinationPort { + // Update this destination + mapping.Destinations[i] = newDest + updated = true + logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d", + keyStr, oldDest.DestinationIP, oldDest.DestinationPort, + newDest.DestinationIP, newDest.DestinationPort) + } + } + + // If we updated any destinations, store the updated mapping back + if updated { + mapping.LastUsed = time.Now() + s.proxyMappings.Store(keyStr, mapping) + updatedCount++ + } + + return true // continue iteration + }) + + if updatedCount > 0 { + logger.Info("Updated %d proxy mappings from %s:%d to %s:%d", + updatedCount, oldDest.DestinationIP, oldDest.DestinationPort, + newDest.DestinationIP, newDest.DestinationPort) + } + + return updatedCount +}