Add rate limit to hole punch

This commit is contained in:
Owen
2026-03-20 16:02:58 -07:00
parent b9261b8fea
commit fcead8cc15

View File

@@ -118,6 +118,13 @@ type Packet struct {
n int n int
} }
// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window.
type holePunchRateLimitEntry struct {
mu sync.Mutex
count int
windowStart time.Time
}
// WireGuard message types // WireGuard message types
const ( const (
WireGuardMessageTypeHandshakeInitiation = 1 WireGuardMessageTypeHandshakeInitiation = 1
@@ -153,6 +160,8 @@ type UDPProxyServer struct {
// Communication pattern tracking for rebuilding sessions // Communication pattern tracking for rebuilding sessions
// Key format: "clientIP:clientPort-destIP:destPort" // Key format: "clientIP:clientPort-destIP:destPort"
commPatterns sync.Map commPatterns sync.Map
// Rate limiter for encrypted hole punch messages, keyed by "ip:port"
holePunchRateLimiter sync.Map
// ReachableAt is the URL where this server can be reached // ReachableAt is the URL where this server can be reached
ReachableAt string ReachableAt string
} }
@@ -210,6 +219,9 @@ func (s *UDPProxyServer) Start() error {
// Start the communication pattern cleanup routine // Start the communication pattern cleanup routine
go s.cleanupIdleCommunicationPatterns() go s.cleanupIdleCommunicationPatterns()
// Start the hole punch rate limiter cleanup routine
go s.cleanupHolePunchRateLimiter()
return nil return nil
} }
@@ -272,6 +284,27 @@ func (s *UDPProxyServer) packetWorker() {
// Process as a WireGuard packet. // Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr) s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else { } else {
// Rate limit: allow at most 2 hole punch messages per IP:Port per second
rateLimitKey := packet.remoteAddr.String()
entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{
windowStart: time.Now(),
})
rlEntry := entryVal.(*holePunchRateLimitEntry)
rlEntry.mu.Lock()
now := time.Now()
if now.Sub(rlEntry.windowStart) >= time.Second {
rlEntry.count = 0
rlEntry.windowStart = now
}
rlEntry.count++
allowed := rlEntry.count <= 2
rlEntry.mu.Unlock()
if !allowed {
logger.Debug("Rate limiting hole punch message from %s", rateLimitKey)
bufferPool.Put(packet.data[:1500])
continue
}
// Process as an encrypted hole punch message // Process as an encrypted hole punch message
var encMsg EncryptedHolePunchMessage var encMsg EncryptedHolePunchMessage
if err := json.Unmarshal(packet.data, &encMsg); err != nil { if err := json.Unmarshal(packet.data, &encMsg); err != nil {
@@ -1030,6 +1063,30 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
} }
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns // cleanupIdleCommunicationPatterns periodically removes idle communication patterns
// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth.
func (s *UDPProxyServer) cleanupHolePunchRateLimiter() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.holePunchRateLimiter.Range(func(key, value interface{}) bool {
rlEntry := value.(*holePunchRateLimitEntry)
rlEntry.mu.Lock()
stale := now.Sub(rlEntry.windowStart) > 10*time.Second
rlEntry.mu.Unlock()
if stale {
s.holePunchRateLimiter.Delete(key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute) ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop() defer ticker.Stop()