Files
netbird/management/server/http/middleware/rate_limiter_test.go
2026-04-21 18:30:33 +02:00

330 lines
8.6 KiB
Go

package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}