diff --git a/relay/relay.go b/relay/relay.go index 0ab5930..c065c29 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -13,12 +13,15 @@ import ( "sync" "time" + "github.com/fosrl/gerbil/internal/metrics" "github.com/fosrl/gerbil/logger" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const relayIfname = "relay" + type EncryptedHolePunchMessage struct { EphemeralPublicKey string `json:"ephemeralPublicKey"` Nonce []byte `json:"nonce"` @@ -290,9 +293,13 @@ func (s *UDPProxyServer) packetWorker() { for packet := range s.packetChan { // Determine packet type by inspecting the first byte. if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 { + metrics.RecordUDPPacket(relayIfname, "wireguard", "in") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(packet.n)) // Process as a WireGuard packet. s.handleWireGuardPacket(packet.data, packet.remoteAddr) } else { + metrics.RecordUDPPacket(relayIfname, "hole_punch", "in") + metrics.RecordUDPPacketSize(relayIfname, "hole_punch", float64(packet.n)) // Rate limit: allow at most 2 hole punch messages per IP:Port per second rateLimitKey := packet.remoteAddr.String() entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{ @@ -310,6 +317,7 @@ func (s *UDPProxyServer) packetWorker() { rlEntry.mu.Unlock() if !allowed { // logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) + metrics.RecordHolePunchEvent(relayIfname, "rate_limited") bufferPool.Put(packet.data[:1500]) continue } @@ -318,6 +326,7 @@ func (s *UDPProxyServer) packetWorker() { var encMsg EncryptedHolePunchMessage if err := json.Unmarshal(packet.data, &encMsg); err != nil { logger.Error("Error unmarshaling encrypted message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -325,6 +334,7 @@ func (s *UDPProxyServer) packetWorker() { if encMsg.EphemeralPublicKey == "" { logger.Error("Received malformed message without ephemeral key") + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -334,6 +344,7 @@ func (s *UDPProxyServer) packetWorker() { decryptedData, err := s.decryptMessage(encMsg) if err != nil { // logger.Error("Failed to decrypt message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -343,6 +354,7 @@ func (s *UDPProxyServer) packetWorker() { var msg HolePunchMessage if err := json.Unmarshal(decryptedData, &msg); err != nil { logger.Error("Error unmarshaling decrypted message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -362,6 +374,7 @@ func (s *UDPProxyServer) packetWorker() { logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port) s.notifyServer(endpoint) s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment + metrics.RecordHolePunchEvent(relayIfname, "success") } // Return the buffer to the pool for reuse. bufferPool.Put(packet.data[:1500]) @@ -429,6 +442,8 @@ func (s *UDPProxyServer) fetchInitialMappings() error { mapping.LastUsed = time.Now() s.proxyMappings.Store(key, mapping) } + metrics.RecordProxyInitialMappings(relayIfname, int64(len(initialMappings.Mappings))) + metrics.RecordProxyMapping(relayIfname, int64(len(initialMappings.Mappings))) logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) return nil } @@ -544,7 +559,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward handshake initiation: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } case WireGuardMessageTypeHandshakeResponse: @@ -556,12 +575,17 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex) // Store the session information - s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: remoteAddr, LastSeen: time.Now(), - }) + } + if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + } // Forward the response to the original sender for _, dest := range proxyMapping.Destinations { @@ -580,7 +604,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Error("Failed to forward handshake response: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } case WireGuardMessageTypeTransportData: @@ -617,7 +645,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + return } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } else { // No known session, fall back to forwarding to all peers logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex) @@ -640,7 +672,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } } @@ -665,7 +701,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Error("Failed to forward WireGuard packet: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } } } @@ -683,6 +723,7 @@ func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr // Create new connection newConn, err := net.DialUDP("udp", nil, destAddr) if err != nil { + metrics.RecordProxyConnectionError(relayIfname, "dial_udp") return nil, fmt.Errorf("failed to create UDP connection: %v", err) } @@ -706,6 +747,8 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd logger.Debug("Error reading response from %s: %v", destAddr.String(), err) return } + metrics.RecordUDPPacket(relayIfname, "wireguard", "in") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n)) // Process the response to track sessions if it's a WireGuard packet if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { @@ -713,12 +756,17 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse { // Store the session mapping for the handshake response sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex) - s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: destAddr, LastSeen: time.Now(), - }) + } + if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + } logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String()) } else if ok && buffer[0] == WireGuardMessageTypeTransportData { // Track communication pattern for session rebuilding (reverse direction) @@ -730,7 +778,11 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd _, err = s.conn.WriteToUDP(buffer[:n], remoteAddr) if err != nil { logger.Error("Failed to forward response: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n)) } } @@ -741,15 +793,18 @@ func (s *UDPProxyServer) cleanupIdleConnections() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.connections.Range(func(key, value interface{}) bool { destConn := value.(*DestinationConn) if now.Sub(destConn.lastUsed) > 10*time.Minute { destConn.conn.Close() s.connections.Delete(key) + metrics.RecordProxyCleanupRemoved(relayIfname, "conn", 1) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "conn", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -764,16 +819,20 @@ func (s *UDPProxyServer) cleanupIdleSessions() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.wgSessions.Range(func(key, value interface{}) bool { session := value.(*WireGuardSession) // Use thread-safe method to read LastSeen if now.Sub(session.GetLastSeen()) > 15*time.Minute { s.wgSessions.Delete(key) + metrics.RecordSession(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "session", 1) logger.Debug("Removed idle session: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "session", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -787,16 +846,20 @@ func (s *UDPProxyServer) cleanupIdleProxyMappings() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.proxyMappings.Range(func(key, value interface{}) bool { mapping := value.(ProxyMapping) // Remove mappings that haven't been used in 30 minutes if now.Sub(mapping.LastUsed) > 30*time.Minute { s.proxyMappings.Delete(key) + metrics.RecordProxyMapping(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "proxy_mapping", 1) logger.Debug("Removed idle proxy mapping: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "proxy_mapping", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -839,6 +902,11 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port) mapping.LastUsed = time.Now() + if _, existed := s.proxyMappings.Load(key); existed { + metrics.RecordProxyMappingUpdate(relayIfname) + } else { + metrics.RecordProxyMapping(relayIfname, 1) + } s.proxyMappings.Store(key, mapping) logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed) @@ -851,6 +919,11 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des Destinations: destinations, LastUsed: time.Now(), } + if _, existed := s.proxyMappings.Load(key); existed { + metrics.RecordProxyMappingUpdate(relayIfname) + } else { + metrics.RecordProxyMapping(relayIfname, 1) + } s.proxyMappings.Store(key, mapping) } @@ -917,6 +990,10 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { for _, key := range keysToDelete { s.wgSessions.Delete(key) } + if len(keysToDelete) > 0 { + metrics.RecordSession(relayIfname, -int64(len(keysToDelete))) + metrics.RecordProxyCleanupRemoved(relayIfname, "session", int64(len(keysToDelete))) + } logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip) } @@ -1077,7 +1154,9 @@ func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr pattern.LastFromDest = now } - s.commPatterns.Store(patternKey, pattern) + if _, loaded := s.commPatterns.LoadOrStore(patternKey, pattern); !loaded { + metrics.RecordCommPattern(relayIfname, 1) + } } } @@ -1095,16 +1174,20 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex) // Check if we already have this session - if _, exists := s.wgSessions.Load(sessionKey); !exists { - s.wgSessions.Store(sessionKey, &WireGuardSession{ - ReceiverIndex: pattern.DestIndex, - SenderIndex: pattern.ClientIndex, - DestAddr: pattern.ToDestination, - LastSeen: time.Now(), - }) - logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", - sessionKey, pattern.ToDestination.String(), pattern.PacketCount) + session := &WireGuardSession{ + ReceiverIndex: pattern.DestIndex, + SenderIndex: pattern.ClientIndex, + DestAddr: pattern.ToDestination, + LastSeen: time.Now(), } + if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + metrics.RecordSessionRebuilt(relayIfname) + } + logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", + sessionKey, pattern.ToDestination.String(), pattern.PacketCount) } } @@ -1139,6 +1222,7 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.commPatterns.Range(func(key, value interface{}) bool { pattern := value.(*CommunicationPattern) @@ -1152,10 +1236,13 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { // Remove patterns that haven't had activity in 20 minutes if now.Sub(lastActivity) > 20*time.Minute { s.commPatterns.Delete(key) + metrics.RecordCommPattern(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "comm_pattern", 1) logger.Debug("Removed idle communication pattern: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "comm_pattern", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return }