// Mirror of https://github.com/netbirdio/netbird.git (synced 2026-04-21 17:56:39 +00:00; 330 lines, 8.6 KiB, Go).
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestAPIRateLimiter_Allow(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60, // 1 per second
|
|
Burst: 2,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
// First two requests should be allowed (burst)
|
|
assert.True(t, rl.Allow("test-key"))
|
|
assert.True(t, rl.Allow("test-key"))
|
|
|
|
// Third request should be denied (exceeded burst)
|
|
assert.False(t, rl.Allow("test-key"))
|
|
|
|
// Different key should be allowed
|
|
assert.True(t, rl.Allow("different-key"))
|
|
}
|
|
|
|
func TestAPIRateLimiter_Middleware(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60, // 1 per second
|
|
Burst: 2,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
// Create a simple handler that returns 200 OK
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Wrap with rate limiter middleware
|
|
handler := rl.Middleware(nextHandler)
|
|
|
|
// First two requests should pass (burst)
|
|
for i := 0; i < 2; i++ {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
|
|
}
|
|
|
|
// Third request should be rate limited
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
|
|
}
|
|
|
|
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 1,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
handler := rl.Middleware(nextHandler)
|
|
|
|
// Request from first IP
|
|
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req1.RemoteAddr = "192.168.1.1:12345"
|
|
rr1 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr1, req1)
|
|
assert.Equal(t, http.StatusOK, rr1.Code)
|
|
|
|
// Second request from first IP should be rate limited
|
|
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req2.RemoteAddr = "192.168.1.1:12345"
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
|
|
|
|
// Request from different IP should be allowed
|
|
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req3.RemoteAddr = "192.168.1.2:12345"
|
|
rr3 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr3, req3)
|
|
assert.Equal(t, http.StatusOK, rr3.Code)
|
|
}
|
|
|
|
func TestGetClientIP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
remoteAddr string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "remote addr with port",
|
|
remoteAddr: "192.168.1.1:12345",
|
|
expected: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "remote addr without port",
|
|
remoteAddr: "192.168.1.1",
|
|
expected: "192.168.1.1",
|
|
},
|
|
{
|
|
name: "IPv6 with port",
|
|
remoteAddr: "[::1]:12345",
|
|
expected: "::1",
|
|
},
|
|
{
|
|
name: "IPv6 without port",
|
|
remoteAddr: "::1",
|
|
expected: "::1",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.RemoteAddr = tc.remoteAddr
|
|
assert.Equal(t, tc.expected, getClientIP(req))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAPIRateLimiter_Reset(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 1,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
// Use up the burst
|
|
assert.True(t, rl.Allow("test-key"))
|
|
assert.False(t, rl.Allow("test-key"))
|
|
|
|
// Reset the limiter
|
|
rl.Reset("test-key")
|
|
|
|
// Should be allowed again
|
|
assert.True(t, rl.Allow("test-key"))
|
|
}
|
|
|
|
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 1,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
assert.True(t, rl.Allow("key"))
|
|
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
|
|
|
rl.SetEnabled(false)
|
|
assert.False(t, rl.Enabled())
|
|
for i := 0; i < 5; i++ {
|
|
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
|
}
|
|
|
|
rl.SetEnabled(true)
|
|
assert.True(t, rl.Enabled())
|
|
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
|
}
|
|
|
|
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 2,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
assert.True(t, rl.Allow("k1"))
|
|
assert.True(t, rl.Allow("k1"))
|
|
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
|
|
|
rl.UpdateConfig(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 10,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
|
|
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
|
// but importantly newly-added keys use the updated config immediately.
|
|
assert.True(t, rl.Allow("k2"))
|
|
for i := 0; i < 9; i++ {
|
|
assert.True(t, rl.Allow("k2"))
|
|
}
|
|
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
|
}
|
|
|
|
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 1,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
rl.UpdateConfig(nil) // must not panic or zero the config
|
|
|
|
assert.True(t, rl.Allow("k"))
|
|
assert.False(t, rl.Allow("k"))
|
|
}
|
|
|
|
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 60,
|
|
Burst: 1,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
assert.True(t, rl.Allow("k"))
|
|
assert.False(t, rl.Allow("k"))
|
|
|
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
|
|
|
rl.Reset("k")
|
|
assert.True(t, rl.Allow("k"))
|
|
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
|
}
|
|
|
|
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
|
RequestsPerMinute: 600,
|
|
Burst: 10,
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
defer rl.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
stop := make(chan struct{})
|
|
|
|
for i := 0; i < 8; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
key := fmt.Sprintf("k%d", id)
|
|
for {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
default:
|
|
rl.Allow(key)
|
|
}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for i := 0; i < 200; i++ {
|
|
select {
|
|
case <-stop:
|
|
return
|
|
default:
|
|
rl.UpdateConfig(&RateLimiterConfig{
|
|
RequestsPerMinute: float64(30 + (i % 90)),
|
|
Burst: 1 + (i % 20),
|
|
CleanupInterval: time.Minute,
|
|
LimiterTTL: time.Minute,
|
|
})
|
|
rl.SetEnabled(i%2 == 0)
|
|
}
|
|
}
|
|
}()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
close(stop)
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
|
t.Setenv(RateLimitingEnabledEnv, "true")
|
|
t.Setenv(RateLimitingRPMEnv, "42")
|
|
t.Setenv(RateLimitingBurstEnv, "7")
|
|
|
|
cfg, enabled := RateLimiterConfigFromEnv()
|
|
assert.True(t, enabled)
|
|
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
|
assert.Equal(t, 7, cfg.Burst)
|
|
|
|
t.Setenv(RateLimitingEnabledEnv, "false")
|
|
_, enabled = RateLimiterConfigFromEnv()
|
|
assert.False(t, enabled)
|
|
|
|
t.Setenv(RateLimitingEnabledEnv, "")
|
|
t.Setenv(RateLimitingRPMEnv, "")
|
|
t.Setenv(RateLimitingBurstEnv, "")
|
|
cfg, enabled = RateLimiterConfigFromEnv()
|
|
assert.False(t, enabled)
|
|
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
|
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
|
|
|
t.Setenv(RateLimitingRPMEnv, "0")
|
|
t.Setenv(RateLimitingBurstEnv, "-5")
|
|
cfg, _ = RateLimiterConfigFromEnv()
|
|
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
|
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
|
}
|