From 4ee9d775324497deafe4df708a319c8d015860b3 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 28 Sep 2025 15:31:34 -0700 Subject: [PATCH] Rebuild sessions --- relay/relay.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/relay/relay.go b/relay/relay.go index 322276a..e74ed87 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -64,6 +64,17 @@ type WireGuardSession struct { LastSeen time.Time } +// Type for tracking bidirectional communication patterns to rebuild sessions +type CommunicationPattern struct { + FromClient *net.UDPAddr // The client address + ToDestination *net.UDPAddr // The destination address + ClientIndex uint32 // The receiver index seen from client + DestIndex uint32 // The receiver index seen from destination + LastFromClient time.Time // Last packet from client to destination + LastFromDest time.Time // Last packet from destination to client + PacketCount int // Number of packets observed +} + type InitialMappings struct { Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port" } @@ -105,6 +116,9 @@ type UDPProxyServer struct { // Session tracking for WireGuard peers // Key format: "senderIndex:receiverIndex" wgSessions sync.Map + // Communication pattern tracking for rebuilding sessions + // Key format: "clientIP:clientPort-destIP:destPort" + commPatterns sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string } @@ -156,6 +170,9 @@ func (s *UDPProxyServer) Start() error { // Start the proxy mapping cleanup routine go s.cleanupIdleProxyMappings() + // Start the communication pattern cleanup routine + go s.cleanupIdleCommunicationPatterns() + return nil } @@ -445,6 +462,9 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD return } + // Track communication pattern for session rebuilding + s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true) + _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) @@ -465,6 +485,9 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD continue } + // Track communication pattern for session rebuilding + s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true) + _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) @@ -548,6 +571,9 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd LastSeen: time.Now(), }) logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String()) + } else if ok && buffer[0] == WireGuardMessageTypeTransportData { + // Track communication pattern for session rebuilding (reverse direction) + s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false) } } @@ -823,3 +849,117 @@ func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestin return updatedCount } + +// trackCommunicationPattern tracks bidirectional communication patterns to rebuild sessions +func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) { + var clientAddr, destAddr *net.UDPAddr + var clientIndex, destIndex uint32 + + if fromClient { + clientAddr = fromAddr + destAddr = toAddr + clientIndex = receiverIndex + destIndex = 0 // We don't know the destination index yet + } else { + clientAddr = toAddr + destAddr = fromAddr + clientIndex = 0 // We don't know the client index yet + destIndex = receiverIndex + } + + patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String()) + now := time.Now() + + if existingPattern, ok := s.commPatterns.Load(patternKey); ok { + pattern := existingPattern.(*CommunicationPattern) + + // Update the pattern + if fromClient { + pattern.LastFromClient = now + if pattern.ClientIndex == 0 { + pattern.ClientIndex = clientIndex + } + } else { + pattern.LastFromDest = now + if pattern.DestIndex == 0 { + pattern.DestIndex = destIndex + } + } + + pattern.PacketCount++ + s.commPatterns.Store(patternKey, pattern) + + // Check if we have bidirectional communication and can rebuild a session + s.tryRebuildSession(pattern) + } else { + // Create new pattern + pattern := &CommunicationPattern{ + FromClient: clientAddr, + ToDestination: destAddr, + ClientIndex: clientIndex, + DestIndex: destIndex, + PacketCount: 1, + } + + if fromClient { + pattern.LastFromClient = now + } else { + pattern.LastFromDest = now + } + + s.commPatterns.Store(patternKey, pattern) + } +} + +// tryRebuildSession attempts to rebuild a WireGuard session from communication patterns +func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { + // Check if we have bidirectional communication within a reasonable time window + timeDiff := pattern.LastFromClient.Sub(pattern.LastFromDest) + if timeDiff < 0 { + timeDiff = -timeDiff + } + + // Only rebuild if we have recent bidirectional communication and both indices + if timeDiff < 30*time.Second && pattern.ClientIndex != 0 && pattern.DestIndex != 0 && pattern.PacketCount >= 4 { + // Create session mapping: client's index maps to destination + sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex) + + // Check if we already have this session + if _, exists := s.wgSessions.Load(sessionKey); !exists { + session := &WireGuardSession{ + ReceiverIndex: pattern.DestIndex, + SenderIndex: pattern.ClientIndex, + DestAddr: pattern.ToDestination, + LastSeen: time.Now(), + } + + s.wgSessions.Store(sessionKey, session) + logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", + sessionKey, pattern.ToDestination.String(), pattern.PacketCount) + } + } +} + +// cleanupIdleCommunicationPatterns periodically removes idle communication patterns +func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { + ticker := time.NewTicker(10 * time.Minute) + for range ticker.C { + now := time.Now() + s.commPatterns.Range(func(key, value interface{}) bool { + pattern := value.(*CommunicationPattern) + + // Get the most recent activity + lastActivity := pattern.LastFromClient + if pattern.LastFromDest.After(lastActivity) { + lastActivity = pattern.LastFromDest + } + + // Remove patterns that haven't had activity in 20 minutes + if now.Sub(lastActivity) > 20*time.Minute { + s.commPatterns.Delete(key) + logger.Debug("Removed idle communication pattern: %s", key) + } + return true + }) + } +}