From c9a6b85e1d98fc41ddeeb74ec66c6eddbddfc3c2 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 7 Apr 2025 21:45:57 -0400 Subject: [PATCH] Attempt to add sender and receiver ids to relaying --- relay/relay.go | 273 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 250 insertions(+), 23 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index 8611241..dd293fc 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -2,6 +2,7 @@ package relay import ( "bytes" + "encoding/binary" "encoding/json" "fmt" "io" @@ -38,7 +39,12 @@ type ClientEndpoint struct { Timestamp int64 `json:"timestamp"` } +// Updated to support multiple destination peers type ProxyMapping struct { + Destinations []PeerDestination `json:"destinations"` +} + +type PeerDestination struct { DestinationIP string `json:"destinationIP"` DestinationPort int `json:"destinationPort"` } @@ -48,6 +54,14 @@ type DestinationConn struct { lastUsed time.Time } +// Type for storing WireGuard handshake information +type WireGuardSession struct { + ReceiverIndex uint32 + SenderIndex uint32 + DestAddr *net.UDPAddr + LastSeen time.Time +} + type InitialMappings struct { Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port" } @@ -59,6 +73,14 @@ type Packet struct { n int } +// WireGuard message types +const ( + WireGuardMessageTypeHandshakeInitiation = 1 + WireGuardMessageTypeHandshakeResponse = 2 + WireGuardMessageTypeCookieReply = 3 + WireGuardMessageTypeTransportData = 4 +) + // --- End Types --- // bufferPool allows reusing buffers to reduce allocations. @@ -77,6 +99,10 @@ type UDPProxyServer struct { connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" privateKey wgtypes.Key packetChan chan Packet + + // Session tracking for WireGuard peers + // Key format: "senderIndex:receiverIndex" + wgSessions sync.Map } // NewUDPProxyServer initializes the server with a buffered packet channel. @@ -119,6 +145,9 @@ func (s *UDPProxyServer) Start() error { // Start the idle connection cleanup routine. go s.cleanupIdleConnections() + // Start the session cleanup routine + go s.cleanupIdleSessions() + return nil } @@ -259,34 +288,201 @@ func (s *UDPProxyServer) fetchInitialMappings() error { return nil } -// Example handleWireGuardPacket remains unchanged. +// Extract WireGuard message indices +func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) { + if len(packet) < 12 { + return 0, 0, false + } + + messageType := packet[0] + if messageType == WireGuardMessageTypeHandshakeInitiation { + // Handshake initiation: extract sender index at offset 4 + senderIndex := binary.LittleEndian.Uint32(packet[4:8]) + return 0, senderIndex, true + } else if messageType == WireGuardMessageTypeHandshakeResponse { + // Handshake response: extract sender index at offset 4 and receiver index at offset 8 + senderIndex := binary.LittleEndian.Uint32(packet[4:8]) + receiverIndex := binary.LittleEndian.Uint32(packet[8:12]) + return receiverIndex, senderIndex, true + } else if messageType == WireGuardMessageTypeTransportData { + // Transport data: extract receiver index at offset 4 + receiverIndex := binary.LittleEndian.Uint32(packet[4:8]) + return receiverIndex, 0, true + } + + return 0, 0, false +} + +// Updated to handle multi-peer WireGuard communication func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { + if len(packet) == 0 { + logger.Error("Received empty packet") + return + } + + messageType := packet[0] + receiverIndex, senderIndex, ok := extractWireGuardIndices(packet) + + if !ok { + logger.Error("Failed to extract WireGuard indices") + return + } + key := remoteAddr.String() - mapping, ok := s.proxyMappings.Load(key) + mappingObj, ok := s.proxyMappings.Load(key) if !ok { logger.Error("No proxy mapping found for %s", key) return } - proxyMapping := mapping.(ProxyMapping) - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", - proxyMapping.DestinationIP, proxyMapping.DestinationPort)) - if err != nil { - logger.Error("Failed to resolve destination address: %v", err) - return - } - conn, err := s.getOrCreateConnection(destAddr, remoteAddr) - if err != nil { - logger.Error("Failed to get/create connection: %v", err) - return - } - _, err = conn.Write(packet) - if err != nil { - logger.Error("Failed to proxy packet: %v", err) + + proxyMapping := mappingObj.(ProxyMapping) + + // Handle different WireGuard message types + switch messageType { + case WireGuardMessageTypeHandshakeInitiation: + // Initial handshake: forward to all peers + logger.Debug("Forwarding handshake initiation from %s (sender index: %d)", remoteAddr, senderIndex) + + for _, dest := range proxyMapping.Destinations { + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + continue + } + + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + continue + } + + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to forward handshake initiation: %v", err) + } + } + + case WireGuardMessageTypeHandshakeResponse: + // Received handshake response: establish session mapping + logger.Debug("Received handshake response with receiver index %d and sender index %d from %s", + receiverIndex, senderIndex, remoteAddr) + + // Create a session key for the peer that sent the initial handshake + sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex) + + // Store the session information + s.wgSessions.Store(sessionKey, &WireGuardSession{ + ReceiverIndex: receiverIndex, + SenderIndex: senderIndex, + DestAddr: remoteAddr, + LastSeen: time.Now(), + }) + + // Forward the response to the original sender + for _, dest := range proxyMapping.Destinations { + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + continue + } + + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + continue + } + + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to forward handshake response: %v", err) + } + } + + case WireGuardMessageTypeTransportData: + // Data packet: forward only to the established session peer + logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr) + + // Look up the session based on the receiver index + var destAddr *net.UDPAddr + + // First check for existing sessions to see if we know where to send this packet + s.wgSessions.Range(func(k, v interface{}) bool { + session := v.(*WireGuardSession) + if session.SenderIndex == receiverIndex { + // Found matching session + destAddr = session.DestAddr + + // Update last seen time + session.LastSeen = time.Now() + s.wgSessions.Store(k, session) + return false // stop iteration + } + return true // continue iteration + }) + + if destAddr != nil { + // We found a specific peer to forward to + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + return + } + + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to forward transport data: %v", err) + } + } else { + // No known session, fall back to forwarding to all peers + logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex) + for _, dest := range proxyMapping.Destinations { + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + continue + } + + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + continue + } + + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to forward transport data: %v", err) + } + } + } + + default: + // Other packet types (like cookie reply) + logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr) + + // Forward to all peers + for _, dest := range proxyMapping.Destinations { + destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + if err != nil { + logger.Error("Failed to resolve destination address: %v", err) + continue + } + + conn, err := s.getOrCreateConnection(destAddr, remoteAddr) + if err != nil { + logger.Error("Failed to get/create connection: %v", err) + continue + } + + _, err = conn.Write(packet) + if err != nil { + logger.Error("Failed to forward WireGuard packet: %v", err) + } + } } } func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) { - key := remoteAddr.String() + key := destAddr.String() + "-" + remoteAddr.String() // Check if we have an existing connection if conn, ok := s.connections.Load(key); ok { @@ -322,6 +518,22 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd return } + // Process the response to track sessions if it's a WireGuard packet + if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { + receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n]) + if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse { + // Store the session mapping for the handshake response + sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex) + s.wgSessions.Store(sessionKey, &WireGuardSession{ + ReceiverIndex: receiverIndex, + SenderIndex: senderIndex, + DestAddr: destAddr, + LastSeen: time.Now(), + }) + logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String()) + } + } + // Forward the response back through the main listener _, err = s.conn.WriteToUDP(buffer[:n], remoteAddr) if err != nil { @@ -346,6 +558,22 @@ func (s *UDPProxyServer) cleanupIdleConnections() { } } +// New method to periodically remove idle sessions +func (s *UDPProxyServer) cleanupIdleSessions() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + now := time.Now() + s.wgSessions.Range(func(key, value interface{}) bool { + session := value.(*WireGuardSession) + if now.Sub(session.LastSeen) > 15*time.Minute { + s.wgSessions.Delete(key) + logger.Debug("Removed idle session: %s", key) + } + return true + }) + } +} + func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { jsonData, err := json.Marshal(endpoint) if err != nil { @@ -380,15 +608,14 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) s.proxyMappings.Store(key, mapping) - logger.Debug("Stored proxy mapping for %s: %v", key, mapping) + logger.Debug("Stored proxy mapping for %s with %d destinations", key, len(mapping.Destinations)) } -func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, - destinationIP string, destinationPort int) { +// Updated to support multiple destinations +func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) { key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort)) mapping := ProxyMapping{ - DestinationIP: destinationIP, - DestinationPort: destinationPort, + Destinations: destinations, } s.proxyMappings.Store(key, mapping) }