mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Compare commits
2 Commits
debug-logg
...
feature/dy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b252de820 | ||
|
|
b6b1b8b338 |
@@ -30,6 +30,7 @@ import (
|
|||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||||
|
return Create(s, func() *middleware.APIRateLimiter {
|
||||||
|
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||||
|
limiter := middleware.NewAPIRateLimiter(cfg)
|
||||||
|
limiter.SetEnabled(enabled)
|
||||||
|
return limiter
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||||
return Create(s, func() *grpc.Server {
|
return Create(s, func() *grpc.Server {
|
||||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
@@ -67,13 +64,10 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
apiPrefix = "/api"
|
apiPrefix = "/api"
|
||||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
|
||||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
|
||||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
if rateLimiter == nil {
|
||||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||||
rpm := 6
|
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
rateLimiter.SetEnabled(false)
|
||||||
value, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
|
||||||
} else {
|
|
||||||
rpm = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
burst := 500
|
|
||||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
|
||||||
value, err := strconv.Atoi(v)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
|
||||||
} else {
|
|
||||||
burst = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
|
||||||
RequestsPerMinute: float64(rpm),
|
|
||||||
Burst: burst,
|
|
||||||
CleanupInterval: 6 * time.Hour,
|
|
||||||
LimiterTTL: 24 * time.Hour,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
accountManager.GetAccountIDFromUserAuth,
|
accountManager.GetAccountIDFromUserAuth,
|
||||||
accountManager.SyncUserJWTGroups,
|
accountManager.SyncUserJWTGroups,
|
||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimitingConfig,
|
rateLimiter,
|
||||||
appMetrics.GetMeter(),
|
appMetrics.GetMeter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -42,14 +42,9 @@ func NewAuthMiddleware(
|
|||||||
ensureAccount EnsureAccountFunc,
|
ensureAccount EnsureAccountFunc,
|
||||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiterConfig *RateLimiterConfig,
|
rateLimiter *APIRateLimiter,
|
||||||
meter metric.Meter,
|
meter metric.Meter,
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
var rateLimiter *APIRateLimiter
|
|
||||||
if rateLimiterConfig != nil {
|
|
||||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
var patUsageTracker *PATUsageTracker
|
var patUsageTracker *PATUsageTracker
|
||||||
if meter != nil {
|
if meter != nil {
|
||||||
var err error
|
var err error
|
||||||
@@ -181,11 +176,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
m.patUsageTracker.IncrementUsage(token)
|
m.patUsageTracker.IncrementUsage(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||||
if !m.rateLimiter.Allow(token) {
|
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||||
|
|||||||
@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
disabledLimiter := NewAPIRateLimiter(nil)
|
||||||
|
disabledLimiter.SetEnabled(false)
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
nil,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
disabledLimiter := NewAPIRateLimiter(nil)
|
||||||
|
disabledLimiter.SetEnabled(false)
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockAuth,
|
mockAuth,
|
||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
nil,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,27 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||||
|
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||||
|
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||||
|
|
||||||
|
defaultAPIRPM = 6
|
||||||
|
defaultAPIBurst = 500
|
||||||
|
)
|
||||||
|
|
||||||
// RateLimiterConfig holds configuration for the API rate limiter
|
// RateLimiterConfig holds configuration for the API rate limiter
|
||||||
type RateLimiterConfig struct {
|
type RateLimiterConfig struct {
|
||||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||||
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||||
|
rpm := defaultAPIRPM
|
||||||
|
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||||
|
value, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||||
|
} else {
|
||||||
|
rpm = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rpm <= 0 {
|
||||||
|
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
||||||
|
rpm = defaultAPIRPM
|
||||||
|
}
|
||||||
|
|
||||||
|
burst := defaultAPIBurst
|
||||||
|
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||||
|
value, err := strconv.Atoi(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||||
|
} else {
|
||||||
|
burst = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if burst <= 0 {
|
||||||
|
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
||||||
|
burst = defaultAPIBurst
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RateLimiterConfig{
|
||||||
|
RequestsPerMinute: float64(rpm),
|
||||||
|
Burst: burst,
|
||||||
|
CleanupInterval: 6 * time.Hour,
|
||||||
|
LimiterTTL: 24 * time.Hour,
|
||||||
|
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
// limiterEntry holds a rate limiter and its last access time
|
// limiterEntry holds a rate limiter and its last access time
|
||||||
type limiterEntry struct {
|
type limiterEntry struct {
|
||||||
limiter *rate.Limiter
|
limiter *rate.Limiter
|
||||||
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
|
|||||||
limiters map[string]*limiterEntry
|
limiters map[string]*limiterEntry
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
|
enabled atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||||
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
|||||||
limiters: make(map[string]*limiterEntry),
|
limiters: make(map[string]*limiterEntry),
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
rl.enabled.Store(true)
|
||||||
|
|
||||||
go rl.cleanupLoop()
|
go rl.cleanupLoop()
|
||||||
|
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||||
|
rl.enabled.Store(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *APIRateLimiter) Enabled() bool {
|
||||||
|
return rl.enabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||||
|
if config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
||||||
|
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||||
|
newBurst := config.Burst
|
||||||
|
|
||||||
|
rl.mu.Lock()
|
||||||
|
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||||
|
rl.config.Burst = newBurst
|
||||||
|
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||||
|
for _, entry := range rl.limiters {
|
||||||
|
snapshot = append(snapshot, entry.limiter)
|
||||||
|
}
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
|
for _, l := range snapshot {
|
||||||
|
l.SetLimit(newRPS)
|
||||||
|
l.SetBurst(newBurst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Allow checks if a request for the given key (token) is allowed
|
// Allow checks if a request for the given key (token) is allowed
|
||||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||||
|
if !rl.enabled.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
limiter := rl.getLimiter(key)
|
limiter := rl.getLimiter(key)
|
||||||
return limiter.Allow()
|
return limiter.Allow()
|
||||||
}
|
}
|
||||||
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
|||||||
// Wait blocks until the rate limiter allows another request for the given key
|
// Wait blocks until the rate limiter allows another request for the given key
|
||||||
// Returns an error if the context is canceled
|
// Returns an error if the context is canceled
|
||||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||||
|
if !rl.enabled.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
limiter := rl.getLimiter(key)
|
limiter := rl.getLimiter(key)
|
||||||
return limiter.Wait(ctx)
|
return limiter.Wait(ctx)
|
||||||
}
|
}
|
||||||
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
|||||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !rl.enabled.Load() {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
clientIP := getClientIP(r)
|
clientIP := getClientIP(r)
|
||||||
if !rl.Allow(clientIP) {
|
if !rl.Allow(clientIP) {
|
||||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
|||||||
// Should be allowed again
|
// Should be allowed again
|
||||||
assert.True(t, rl.Allow("test-key"))
|
assert.True(t, rl.Allow("test-key"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||||
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
assert.True(t, rl.Allow("key"))
|
||||||
|
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||||
|
|
||||||
|
rl.SetEnabled(false)
|
||||||
|
assert.False(t, rl.Enabled())
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.SetEnabled(true)
|
||||||
|
assert.True(t, rl.Enabled())
|
||||||
|
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||||
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
Burst: 2,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
assert.True(t, rl.Allow("k1"))
|
||||||
|
assert.True(t, rl.Allow("k1"))
|
||||||
|
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||||
|
|
||||||
|
rl.UpdateConfig(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
Burst: 10,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
|
||||||
|
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||||
|
// but importantly newly-added keys use the updated config immediately.
|
||||||
|
assert.True(t, rl.Allow("k2"))
|
||||||
|
for i := 0; i < 9; i++ {
|
||||||
|
assert.True(t, rl.Allow("k2"))
|
||||||
|
}
|
||||||
|
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||||
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||||
|
|
||||||
|
assert.True(t, rl.Allow("k"))
|
||||||
|
assert.False(t, rl.Allow("k"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
||||||
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 60,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
assert.True(t, rl.Allow("k"))
|
||||||
|
assert.False(t, rl.Allow("k"))
|
||||||
|
|
||||||
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||||
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||||
|
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||||
|
|
||||||
|
rl.Reset("k")
|
||||||
|
assert.True(t, rl.Allow("k"))
|
||||||
|
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||||
|
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 600,
|
||||||
|
Burst: 10,
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
stop := make(chan struct{})
|
||||||
|
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
key := fmt.Sprintf("k%d", id)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
rl.Allow(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
rl.UpdateConfig(&RateLimiterConfig{
|
||||||
|
RequestsPerMinute: float64(30 + (i % 90)),
|
||||||
|
Burst: 1 + (i % 20),
|
||||||
|
CleanupInterval: time.Minute,
|
||||||
|
LimiterTTL: time.Minute,
|
||||||
|
})
|
||||||
|
rl.SetEnabled(i%2 == 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(stop)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||||
|
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||||
|
t.Setenv(RateLimitingRPMEnv, "42")
|
||||||
|
t.Setenv(RateLimitingBurstEnv, "7")
|
||||||
|
|
||||||
|
cfg, enabled := RateLimiterConfigFromEnv()
|
||||||
|
assert.True(t, enabled)
|
||||||
|
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||||
|
assert.Equal(t, 7, cfg.Burst)
|
||||||
|
|
||||||
|
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||||
|
_, enabled = RateLimiterConfigFromEnv()
|
||||||
|
assert.False(t, enabled)
|
||||||
|
|
||||||
|
t.Setenv(RateLimitingEnabledEnv, "")
|
||||||
|
t.Setenv(RateLimitingRPMEnv, "")
|
||||||
|
t.Setenv(RateLimitingBurstEnv, "")
|
||||||
|
cfg, enabled = RateLimiterConfigFromEnv()
|
||||||
|
assert.False(t, enabled)
|
||||||
|
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||||
|
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||||
|
|
||||||
|
t.Setenv(RateLimitingRPMEnv, "0")
|
||||||
|
t.Setenv(RateLimitingBurstEnv, "-5")
|
||||||
|
cfg, _ = RateLimiterConfigFromEnv()
|
||||||
|
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
||||||
|
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
||||||
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user