mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-12 03:09:57 +00:00
Add metrics tracking for UDP packet handling and session management
This commit is contained in:
115
relay/relay.go
115
relay/relay.go
@@ -13,12 +13,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/internal/metrics"
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const relayIfname = "relay"
|
||||
|
||||
type EncryptedHolePunchMessage struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
@@ -290,9 +293,13 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
for packet := range s.packetChan {
|
||||
// Determine packet type by inspecting the first byte.
|
||||
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "in")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(packet.n))
|
||||
// Process as a WireGuard packet.
|
||||
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
||||
} else {
|
||||
metrics.RecordUDPPacket(relayIfname, "hole_punch", "in")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "hole_punch", float64(packet.n))
|
||||
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
|
||||
rateLimitKey := packet.remoteAddr.String()
|
||||
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
|
||||
@@ -310,6 +317,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
rlEntry.mu.Unlock()
|
||||
if !allowed {
|
||||
// logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
|
||||
metrics.RecordHolePunchEvent(relayIfname, "rate_limited")
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
}
|
||||
@@ -318,6 +326,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
var encMsg EncryptedHolePunchMessage
|
||||
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
||||
logger.Error("Error unmarshaling encrypted message: %v", err)
|
||||
metrics.RecordHolePunchEvent(relayIfname, "error")
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -325,6 +334,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
|
||||
if encMsg.EphemeralPublicKey == "" {
|
||||
logger.Error("Received malformed message without ephemeral key")
|
||||
metrics.RecordHolePunchEvent(relayIfname, "error")
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -334,6 +344,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
decryptedData, err := s.decryptMessage(encMsg)
|
||||
if err != nil {
|
||||
// logger.Error("Failed to decrypt message: %v", err)
|
||||
metrics.RecordHolePunchEvent(relayIfname, "error")
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -343,6 +354,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
var msg HolePunchMessage
|
||||
if err := json.Unmarshal(decryptedData, &msg); err != nil {
|
||||
logger.Error("Error unmarshaling decrypted message: %v", err)
|
||||
metrics.RecordHolePunchEvent(relayIfname, "error")
|
||||
// Return the buffer to the pool for reuse and continue with next packet
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
continue
|
||||
@@ -362,6 +374,7 @@ func (s *UDPProxyServer) packetWorker() {
|
||||
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
||||
s.notifyServer(endpoint)
|
||||
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
||||
metrics.RecordHolePunchEvent(relayIfname, "success")
|
||||
}
|
||||
// Return the buffer to the pool for reuse.
|
||||
bufferPool.Put(packet.data[:1500])
|
||||
@@ -429,6 +442,8 @@ func (s *UDPProxyServer) fetchInitialMappings() error {
|
||||
mapping.LastUsed = time.Now()
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
}
|
||||
metrics.RecordProxyInitialMappings(relayIfname, int64(len(initialMappings.Mappings)))
|
||||
metrics.RecordProxyMapping(relayIfname, int64(len(initialMappings.Mappings)))
|
||||
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
||||
return nil
|
||||
}
|
||||
@@ -544,7 +559,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward handshake initiation: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
continue
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
|
||||
}
|
||||
|
||||
case WireGuardMessageTypeHandshakeResponse:
|
||||
@@ -556,12 +575,17 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
|
||||
|
||||
// Store the session information
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
session := &WireGuardSession{
|
||||
ReceiverIndex: receiverIndex,
|
||||
SenderIndex: senderIndex,
|
||||
DestAddr: remoteAddr,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
}
|
||||
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
|
||||
s.wgSessions.Store(sessionKey, session)
|
||||
} else {
|
||||
metrics.RecordSession(relayIfname, 1)
|
||||
}
|
||||
|
||||
// Forward the response to the original sender
|
||||
for _, dest := range proxyMapping.Destinations {
|
||||
@@ -580,7 +604,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward handshake response: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
continue
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
|
||||
}
|
||||
|
||||
case WireGuardMessageTypeTransportData:
|
||||
@@ -617,7 +645,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
return
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
|
||||
} else {
|
||||
// No known session, fall back to forwarding to all peers
|
||||
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||
@@ -640,7 +672,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to forward transport data: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
continue
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -665,7 +701,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
_, err = conn.Write(packet)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward WireGuard packet: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
continue
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -683,6 +723,7 @@ func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr
|
||||
// Create new connection
|
||||
newConn, err := net.DialUDP("udp", nil, destAddr)
|
||||
if err != nil {
|
||||
metrics.RecordProxyConnectionError(relayIfname, "dial_udp")
|
||||
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
@@ -706,6 +747,8 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
|
||||
return
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "in")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n))
|
||||
|
||||
// Process the response to track sessions if it's a WireGuard packet
|
||||
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
|
||||
@@ -713,12 +756,17 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
|
||||
// Store the session mapping for the handshake response
|
||||
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
session := &WireGuardSession{
|
||||
ReceiverIndex: receiverIndex,
|
||||
SenderIndex: senderIndex,
|
||||
DestAddr: destAddr,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
}
|
||||
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
|
||||
s.wgSessions.Store(sessionKey, session)
|
||||
} else {
|
||||
metrics.RecordSession(relayIfname, 1)
|
||||
}
|
||||
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
|
||||
} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
|
||||
// Track communication pattern for session rebuilding (reverse direction)
|
||||
@@ -730,7 +778,11 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to forward response: %v", err)
|
||||
metrics.RecordProxyConnectionError(relayIfname, "write_udp")
|
||||
continue
|
||||
}
|
||||
metrics.RecordUDPPacket(relayIfname, "wireguard", "out")
|
||||
metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -741,15 +793,18 @@ func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanupStart := time.Now()
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
metrics.RecordProxyCleanupRemoved(relayIfname, "conn", 1)
|
||||
}
|
||||
return true
|
||||
})
|
||||
metrics.RecordProxyIdleCleanupDuration(relayIfname, "conn", time.Since(cleanupStart).Seconds())
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -764,16 +819,20 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanupStart := time.Now()
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
metrics.RecordSession(relayIfname, -1)
|
||||
metrics.RecordProxyCleanupRemoved(relayIfname, "session", 1)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
metrics.RecordProxyIdleCleanupDuration(relayIfname, "session", time.Since(cleanupStart).Seconds())
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -787,16 +846,20 @@ func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanupStart := time.Now()
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
metrics.RecordProxyMapping(relayIfname, -1)
|
||||
metrics.RecordProxyCleanupRemoved(relayIfname, "proxy_mapping", 1)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
metrics.RecordProxyIdleCleanupDuration(relayIfname, "proxy_mapping", time.Since(cleanupStart).Seconds())
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -839,6 +902,11 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
||||
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
|
||||
logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
|
||||
mapping.LastUsed = time.Now()
|
||||
if _, existed := s.proxyMappings.Load(key); existed {
|
||||
metrics.RecordProxyMappingUpdate(relayIfname)
|
||||
} else {
|
||||
metrics.RecordProxyMapping(relayIfname, 1)
|
||||
}
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
|
||||
logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
|
||||
@@ -851,6 +919,11 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des
|
||||
Destinations: destinations,
|
||||
LastUsed: time.Now(),
|
||||
}
|
||||
if _, existed := s.proxyMappings.Load(key); existed {
|
||||
metrics.RecordProxyMappingUpdate(relayIfname)
|
||||
} else {
|
||||
metrics.RecordProxyMapping(relayIfname, 1)
|
||||
}
|
||||
s.proxyMappings.Store(key, mapping)
|
||||
}
|
||||
|
||||
@@ -917,6 +990,10 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
for _, key := range keysToDelete {
|
||||
s.wgSessions.Delete(key)
|
||||
}
|
||||
if len(keysToDelete) > 0 {
|
||||
metrics.RecordSession(relayIfname, -int64(len(keysToDelete)))
|
||||
metrics.RecordProxyCleanupRemoved(relayIfname, "session", int64(len(keysToDelete)))
|
||||
}
|
||||
|
||||
logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
||||
}
|
||||
@@ -1077,7 +1154,9 @@ func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr
|
||||
pattern.LastFromDest = now
|
||||
}
|
||||
|
||||
s.commPatterns.Store(patternKey, pattern)
|
||||
if _, loaded := s.commPatterns.LoadOrStore(patternKey, pattern); !loaded {
|
||||
metrics.RecordCommPattern(relayIfname, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1095,16 +1174,20 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
|
||||
|
||||
// Check if we already have this session
|
||||
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
ReceiverIndex: pattern.DestIndex,
|
||||
SenderIndex: pattern.ClientIndex,
|
||||
DestAddr: pattern.ToDestination,
|
||||
LastSeen: time.Now(),
|
||||
})
|
||||
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
||||
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
||||
session := &WireGuardSession{
|
||||
ReceiverIndex: pattern.DestIndex,
|
||||
SenderIndex: pattern.ClientIndex,
|
||||
DestAddr: pattern.ToDestination,
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded {
|
||||
s.wgSessions.Store(sessionKey, session)
|
||||
} else {
|
||||
metrics.RecordSession(relayIfname, 1)
|
||||
metrics.RecordSessionRebuilt(relayIfname)
|
||||
}
|
||||
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
||||
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1139,6 +1222,7 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanupStart := time.Now()
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
@@ -1152,10 +1236,13 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
metrics.RecordCommPattern(relayIfname, -1)
|
||||
metrics.RecordProxyCleanupRemoved(relayIfname, "comm_pattern", 1)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
metrics.RecordProxyIdleCleanupDuration(relayIfname, "comm_pattern", time.Since(cleanupStart).Seconds())
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user