Compare commits

...

2 Commits

Author SHA1 Message Date
crn4
6b252de820 review comments 2026-04-21 18:30:33 +02:00
crn4
b6b1b8b338 feat: changeable pat rate limiting 2026-04-21 17:41:51 +02:00
7 changed files with 304 additions and 59 deletions

View File

@@ -30,6 +30,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
}) })
} }
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server { return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -5,9 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/netip" "net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/cors" "github.com/rs/cors"
@@ -67,13 +64,10 @@ import (
const ( const (
apiPrefix = "/api" apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
) )
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err) return nil, fmt.Errorf("failed to add bypass path: %w", err)
} }
var rateLimitingConfig *middleware.RateLimiterConfig if rateLimiter == nil {
if os.Getenv(rateLimitingEnabledKey) == "true" { log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rpm := 6 rateLimiter = middleware.NewAPIRateLimiter(nil)
if v := os.Getenv(rateLimitingRPMKey); v != "" { rateLimiter.SetEnabled(false)
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
} }
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth, accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups, accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth, accountManager.GetUserFromUserAuth,
rateLimitingConfig, rateLimiter,
appMetrics.GetMeter(), appMetrics.GetMeter(),
) )

View File

@@ -42,14 +42,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc, ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc, syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc, getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig, rateLimiter *APIRateLimiter,
meter metric.Meter, meter metric.Meter,
) *AuthMiddleware { ) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker var patUsageTracker *PATUsageTracker
if meter != nil { if meter != nil {
var err error var err error
@@ -181,11 +176,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token) m.patUsageTracker.IncrementUsage(token)
} }
if m.rateLimiter != nil && !isTerraformRequest(r) { if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests") return r, status.Errorf(status.TooManyRequests, "too many requests")
} }
}
ctx := r.Context() ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT, GetPATInfoFunc: mockGetAccountInfoFromPAT,
} }
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockAuth, mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
nil, disabledLimiter,
nil, nil,
) )
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT, GetPATInfoFunc: mockGetAccountInfoFromPAT,
} }
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockAuth, mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
nil, disabledLimiter,
nil, nil,
) )

View File

@@ -4,14 +4,27 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"os"
"strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter // RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct { type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished // RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
} }
} }
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time // limiterEntry holds a rate limiter and its last access time
type limiterEntry struct { type limiterEntry struct {
limiter *rate.Limiter limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry limiters map[string]*limiterEntry
mu sync.RWMutex mu sync.RWMutex
stopChan chan struct{} stopChan chan struct{}
enabled atomic.Bool
} }
// NewAPIRateLimiter creates a new API rate limiter with the given configuration // NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry), limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
} }
rl.enabled.Store(true)
go rl.cleanupLoop() go rl.cleanupLoop()
return rl return rl
} }
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed // Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool { func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key) limiter := rl.getLimiter(key)
return limiter.Allow() return limiter.Allow()
} }
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key // Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled // Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key) limiter := rl.getLimiter(key)
return limiter.Wait(ctx) return limiter.Wait(ctx)
} }
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded. // Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r) clientIP := getClientIP(r)
if !rl.Allow(clientIP) { if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again // Should be allowed again
assert.True(t, rl.Allow("test-key")) assert.True(t, rl.Allow("test-key"))
} }
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }