Add metrics tracking for UDP packet handling and session management

This commit is contained in:
Marc Schäfer
2026-04-03 18:15:58 +02:00
parent e47a57cb4f
commit 652d9c5c68

View File

@@ -13,12 +13,15 @@ import (
"sync"
"time"
"github.com/fosrl/gerbil/internal/metrics"
"github.com/fosrl/gerbil/logger"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const relayIfname = "relay"
type EncryptedHolePunchMessage struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
@@ -290,9 +293,13 @@ func (s *UDPProxyServer) packetWorker() {
for packet := range s.packetChan {
// Determine packet type by inspecting the first byte.
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
metrics.RecordUDPPacket(relayIfname, "wireguard", "in")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(packet.n))
// Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else {
metrics.RecordUDPPacket(relayIfname, "hole_punch", "in")
metrics.RecordUDPPacketSize(relayIfname, "hole_punch", float64(packet.n))
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
rateLimitKey := packet.remoteAddr.String()
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
@@ -310,6 +317,7 @@ func (s *UDPProxyServer) packetWorker() {
rlEntry.mu.Unlock()
if !allowed {
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
metrics.RecordHolePunchEvent(relayIfname, "rate_limited")
bufferPool.Put(packet.data[:1500])
continue
}
@@ -318,6 +326,7 @@ func (s *UDPProxyServer) packetWorker() {
var encMsg EncryptedHolePunchMessage
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
logger.Error("Error unmarshaling encrypted message: %v", err)
metrics.RecordHolePunchEvent(relayIfname, "error")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
@@ -325,6 +334,7 @@ func (s *UDPProxyServer) packetWorker() {
if encMsg.EphemeralPublicKey == "" {
logger.Error("Received malformed message without ephemeral key")
metrics.RecordHolePunchEvent(relayIfname, "error")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
@@ -334,6 +344,7 @@ func (s *UDPProxyServer) packetWorker() {
decryptedData, err := s.decryptMessage(encMsg)
if err != nil {
// logger.Error("Failed to decrypt message: %v", err)
metrics.RecordHolePunchEvent(relayIfname, "error")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
@@ -343,6 +354,7 @@ func (s *UDPProxyServer) packetWorker() {
var msg HolePunchMessage
if err := json.Unmarshal(decryptedData, &msg); err != nil {
logger.Error("Error unmarshaling decrypted message: %v", err)
metrics.RecordHolePunchEvent(relayIfname, "error")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
@@ -362,6 +374,7 @@ func (s *UDPProxyServer) packetWorker() {
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
s.notifyServer(endpoint)
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
metrics.RecordHolePunchEvent(relayIfname, "success")
}
// Return the buffer to the pool for reuse.
bufferPool.Put(packet.data[:1500])
@@ -429,6 +442,8 @@ func (s *UDPProxyServer) fetchInitialMappings() error {
mapping.LastUsed = time.Now()
s.proxyMappings.Store(key, mapping)
}
metrics.RecordProxyInitialMappings(relayIfname, int64(len(initialMappings.Mappings)))
metrics.RecordProxyMapping(relayIfname, int64(len(initialMappings.Mappings)))
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
return nil
}
@@ -544,7 +559,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Debug("Failed to forward handshake initiation: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
continue
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
}
case WireGuardMessageTypeHandshakeResponse:
@@ -556,12 +575,17 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
// Store the session information
s.wgSessions.Store(sessionKey, &WireGuardSession{
session := &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: remoteAddr,
LastSeen: time.Now(),
})
}
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
s.wgSessions.Store(sessionKey, session)
} else {
metrics.RecordSession(relayIfname, 1)
}
// Forward the response to the original sender
for _, dest := range proxyMapping.Destinations {
@@ -580,7 +604,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward handshake response: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
continue
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
}
case WireGuardMessageTypeTransportData:
@@ -617,7 +645,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Debug("Failed to forward transport data: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
return
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
} else {
// No known session, fall back to forwarding to all peers
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
@@ -640,7 +672,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Debug("Failed to forward transport data: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
continue
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
}
}
@@ -665,7 +701,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward WireGuard packet: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
continue
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
}
}
}
@@ -683,6 +723,7 @@ func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr
// Create new connection
newConn, err := net.DialUDP("udp", nil, destAddr)
if err != nil {
metrics.RecordProxyConnectionError(relayIfname, "dial_udp")
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
}
@@ -706,6 +747,8 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
return
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "in")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n))
// Process the response to track sessions if it's a WireGuard packet
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
@@ -713,12 +756,17 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
// Store the session mapping for the handshake response
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
s.wgSessions.Store(sessionKey, &WireGuardSession{
session := &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: destAddr,
LastSeen: time.Now(),
})
}
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
s.wgSessions.Store(sessionKey, session)
} else {
metrics.RecordSession(relayIfname, 1)
}
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
// Track communication pattern for session rebuilding (reverse direction)
@@ -730,7 +778,11 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
if err != nil {
logger.Error("Failed to forward response: %v", err)
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
continue
}
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n))
}
}
@@ -741,15 +793,18 @@ func (s *UDPProxyServer) cleanupIdleConnections() {
for {
select {
case <-ticker.C:
cleanupStart := time.Now()
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
metrics.RecordProxyCleanupRemoved(relayIfname, "conn", 1)
}
return true
})
metrics.RecordProxyIdleCleanupDuration(relayIfname, "conn", time.Since(cleanupStart).Seconds())
case <-s.ctx.Done():
return
}
@@ -764,16 +819,20 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
for {
select {
case <-ticker.C:
cleanupStart := time.Now()
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
metrics.RecordSession(relayIfname, -1)
metrics.RecordProxyCleanupRemoved(relayIfname, "session", 1)
logger.Debug("Removed idle session: %s", key)
}
return true
})
metrics.RecordProxyIdleCleanupDuration(relayIfname, "session", time.Since(cleanupStart).Seconds())
case <-s.ctx.Done():
return
}
@@ -787,16 +846,20 @@ func (s *UDPProxyServer) cleanupIdleProxyMappings() {
for {
select {
case <-ticker.C:
cleanupStart := time.Now()
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
metrics.RecordProxyMapping(relayIfname, -1)
metrics.RecordProxyCleanupRemoved(relayIfname, "proxy_mapping", 1)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
metrics.RecordProxyIdleCleanupDuration(relayIfname, "proxy_mapping", time.Since(cleanupStart).Seconds())
case <-s.ctx.Done():
return
}
@@ -839,6 +902,11 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
mapping.LastUsed = time.Now()
if _, existed := s.proxyMappings.Load(key); existed {
metrics.RecordProxyMappingUpdate(relayIfname)
} else {
metrics.RecordProxyMapping(relayIfname, 1)
}
s.proxyMappings.Store(key, mapping)
logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
@@ -851,6 +919,11 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des
Destinations: destinations,
LastUsed: time.Now(),
}
if _, existed := s.proxyMappings.Load(key); existed {
metrics.RecordProxyMappingUpdate(relayIfname)
} else {
metrics.RecordProxyMapping(relayIfname, 1)
}
s.proxyMappings.Store(key, mapping)
}
@@ -917,6 +990,10 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
for _, key := range keysToDelete {
s.wgSessions.Delete(key)
}
if len(keysToDelete) > 0 {
metrics.RecordSession(relayIfname, -int64(len(keysToDelete)))
metrics.RecordProxyCleanupRemoved(relayIfname, "session", int64(len(keysToDelete)))
}
logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
}
@@ -1077,7 +1154,9 @@ func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr
pattern.LastFromDest = now
}
s.commPatterns.Store(patternKey, pattern)
if _, loaded := s.commPatterns.LoadOrStore(patternKey, pattern); !loaded {
metrics.RecordCommPattern(relayIfname, 1)
}
}
}
@@ -1095,16 +1174,20 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
// Check if we already have this session
if _, exists := s.wgSessions.Load(sessionKey); !exists {
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination,
LastSeen: time.Now(),
})
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
session := &WireGuardSession{
ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination,
LastSeen: time.Now(),
}
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
s.wgSessions.Store(sessionKey, session)
} else {
metrics.RecordSession(relayIfname, 1)
metrics.RecordSessionRebuilt(relayIfname)
}
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
}
}
@@ -1139,6 +1222,7 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
for {
select {
case <-ticker.C:
cleanupStart := time.Now()
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
@@ -1152,10 +1236,13 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
metrics.RecordCommPattern(relayIfname, -1)
metrics.RecordProxyCleanupRemoved(relayIfname, "comm_pattern", 1)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
metrics.RecordProxyIdleCleanupDuration(relayIfname, "comm_pattern", time.Since(cleanupStart).Seconds())
case <-s.ctx.Done():
return
}