diff --git a/relay/relay.go b/relay/relay.go index 22aff76..8fabbff 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -118,6 +118,13 @@ type Packet struct { n int } +// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window. +type holePunchRateLimitEntry struct { + mu sync.Mutex + count int + windowStart time.Time +} + // WireGuard message types const ( WireGuardMessageTypeHandshakeInitiation = 1 @@ -153,6 +160,8 @@ type UDPProxyServer struct { // Communication pattern tracking for rebuilding sessions // Key format: "clientIP:clientPort-destIP:destPort" commPatterns sync.Map + // Rate limiter for encrypted hole punch messages, keyed by "ip:port" + holePunchRateLimiter sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string } @@ -210,6 +219,9 @@ func (s *UDPProxyServer) Start() error { // Start the communication pattern cleanup routine go s.cleanupIdleCommunicationPatterns() + // Start the hole punch rate limiter cleanup routine + go s.cleanupHolePunchRateLimiter() + return nil } @@ -272,6 +284,27 @@ func (s *UDPProxyServer) packetWorker() { // Process as a WireGuard packet. s.handleWireGuardPacket(packet.data, packet.remoteAddr) } else { + // Rate limit: allow at most 2 hole punch messages per IP:Port per second + rateLimitKey := packet.remoteAddr.String() + entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{ + windowStart: time.Now(), + }) + rlEntry := entryVal.(*holePunchRateLimitEntry) + rlEntry.mu.Lock() + now := time.Now() + if now.Sub(rlEntry.windowStart) >= time.Second { + rlEntry.count = 0 + rlEntry.windowStart = now + } + rlEntry.count++ + allowed := rlEntry.count <= 2 + rlEntry.mu.Unlock() + if !allowed { + logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) + bufferPool.Put(packet.data[:1500]) + continue + } + // Process as an encrypted hole punch message var encMsg EncryptedHolePunchMessage if err := json.Unmarshal(packet.data, &encMsg); err != nil { @@ -1030,6 +1063,30 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { } // cleanupIdleCommunicationPatterns periodically removes idle communication patterns +// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth. +func (s *UDPProxyServer) cleanupHolePunchRateLimiter() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.holePunchRateLimiter.Range(func(key, value interface{}) bool { + rlEntry := value.(*holePunchRateLimitEntry) + rlEntry.mu.Lock() + stale := now.Sub(rlEntry.windowStart) > 10*time.Second + rlEntry.mu.Unlock() + if stale { + s.holePunchRateLimiter.Delete(key) + } + return true + }) + case <-s.ctx.Done(): + return + } + } +} + func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { ticker := time.NewTicker(10 * time.Minute) defer ticker.Stop()