mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 10:16:38 +00:00
feat: changeable pat rate limiting
This commit is contained in:
@@ -42,14 +42,9 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -34,6 +47,35 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -46,6 +88,7 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -59,14 +102,49 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -74,6 +152,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -153,6 +234,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -156,3 +158,145 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("key"))
|
||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||
|
||||
rl.SetEnabled(false)
|
||||
assert.False(t, rl.Enabled())
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||
}
|
||||
|
||||
rl.SetEnabled(true)
|
||||
assert.True(t, rl.Enabled())
|
||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
|
||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||
// but importantly newly-added keys use the updated config immediately.
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
for i := 0; i < 9; i++ {
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
}
|
||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 600,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("k%d", id)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.Allow(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 200; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: float64(30 + (i % 90)),
|
||||
Burst: 1 + (i % 20),
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
rl.SetEnabled(i%2 == 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||
t.Setenv(RateLimitingRPMEnv, "42")
|
||||
t.Setenv(RateLimitingBurstEnv, "7")
|
||||
|
||||
cfg, enabled := RateLimiterConfigFromEnv()
|
||||
assert.True(t, enabled)
|
||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, 7, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||
_, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "")
|
||||
t.Setenv(RateLimitingRPMEnv, "")
|
||||
t.Setenv(RateLimitingBurstEnv, "")
|
||||
cfg, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user