mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 02:06:39 +00:00
feat: changeable pat rate limiting
This commit is contained in:
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -34,6 +47,35 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -46,6 +88,7 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -59,14 +102,49 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -74,6 +152,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -153,6 +234,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
Reference in New Issue
Block a user