diff --git a/relay/relay.go b/relay/relay.go index 22aff76..190d077 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -153,6 +153,9 @@ type UDPProxyServer struct { // Communication pattern tracking for rebuilding sessions // Key format: "clientIP:clientPort-destIP:destPort" commPatterns sync.Map + // Cache for resolved UDP addresses to avoid per-packet DNS lookups + // Key: "ip:port" string, Value: *net.UDPAddr + addrCache sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string } @@ -416,6 +419,43 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) { return 0, 0, false } +// cachedAddr holds a resolved UDP address with TTL +type cachedAddr struct { + addr *net.UDPAddr + expiresAt time.Time +} + +// addrCacheTTL is how long resolved addresses are cached before re-resolving +const addrCacheTTL = 5 * time.Minute + +// getCachedAddr returns a cached UDP address or resolves and caches it. +// This avoids per-packet DNS lookups which are a major throughput bottleneck. +func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) { + key := fmt.Sprintf("%s:%d", ip, port) + + // Check cache first + if cached, ok := s.addrCache.Load(key); ok { + entry := cached.(*cachedAddr) + if time.Now().Before(entry.expiresAt) { + return entry.addr, nil + } + // Cache expired, delete and re-resolve + s.addrCache.Delete(key) + } + + // Resolve and cache + addr, err := net.ResolveUDPAddr("udp", key) + if err != nil { + return nil, err + } + + s.addrCache.Store(key, &cachedAddr{ + addr: addr, + expiresAt: time.Now().Add(addrCacheTTL), + }) + return addr, nil +} + // Updated to handle multi-peer WireGuard communication func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { if len(packet) == 0 { @@ -450,7 +490,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations) for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -486,7 +526,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // Forward the response to the original sender for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -543,7 +583,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // No known session, fall back to forwarding to all peers logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex) for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -571,7 +611,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // Forward to all peers for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue