mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] add pat rate limiting (#4741)
This commit is contained in:
2
go.mod
2
go.mod
@@ -108,6 +108,7 @@ require (
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0
|
||||
golang.org/x/term v0.33.0
|
||||
golang.org/x/time v0.12.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -245,7 +246,6 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||
|
||||
@@ -4,9 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
@@ -38,7 +42,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(
|
||||
@@ -58,11 +67,42 @@ func NewAPIHandler(
|
||||
settingsManager settings.Manager,
|
||||
) (http.Handler, error) {
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
@@ -29,6 +29,7 @@ type AuthMiddleware struct {
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
err = status.Errorf(status.Unauthorized, "token invalid")
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,9 @@ const (
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
tokenID2 = "tokenID2"
|
||||
PAT = "nbp_PAT"
|
||||
PAT2 = "nbp_PAT2"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
tokenID2: {
|
||||
ID: tokenID2,
|
||||
Name: "My second token",
|
||||
HashedToken: "someHash2",
|
||||
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
if token == PAT {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
if token == PAT2 {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
if token == tokenID || token == tokenID2 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
|
||||
// Configure rate limiter: 10 requests per minute with burst of 5
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 5,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make burst requests - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, successCount, "All burst requests should succeed")
|
||||
|
||||
// The 6th request should fail (exceeded burst)
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
|
||||
// Configure very low rate limit: 1 request per minute
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request should fail (rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
|
||||
// Configure strict rate limit
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make multiple requests with Bearer token - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
|
||||
// Configure rate limiter
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Use first PAT token
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
|
||||
|
||||
// Second request with same token should fail
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
|
||||
|
||||
// Use second PAT token - should succeed because it has independent rate limit
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
|
||||
|
||||
// Second request with PAT2 should also be rate limited
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
|
||||
|
||||
// JWT should still work (not rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
|
||||
// Configure rate limiter with short cleanup interval and TTL for testing
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
LimiterTTL: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request - should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request immediately - should fail (burst exhausted)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
|
||||
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
// After cleanup, the limiter should be removed and recreated with full burst capacity
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
|
||||
|
||||
// Verify it's a fresh limiter by checking burst is reset
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
146
management/server/http/middleware/rate_limiter.go
Normal file
146
management/server/http/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
RequestsPerMinute float64
|
||||
// Burst defines the maximum number of requests that can be made in a burst
|
||||
Burst int
|
||||
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
|
||||
CleanupInterval time.Duration
|
||||
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
|
||||
LimiterTTL time.Duration
|
||||
}
|
||||
|
||||
// DefaultRateLimiterConfig returns a default configuration
|
||||
func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 120,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// APIRateLimiter manages rate limiting for API tokens
|
||||
type APIRateLimiter struct {
|
||||
config *RateLimiterConfig
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
if config == nil {
|
||||
config = DefaultRateLimiterConfig()
|
||||
}
|
||||
|
||||
rl := &APIRateLimiter{
|
||||
config: config,
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
|
||||
// getLimiter retrieves or creates a rate limiter for the given key
|
||||
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
entry, exists := rl.limiters[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
rl.mu.Lock()
|
||||
entry.lastAccess = time.Now()
|
||||
rl.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if entry, exists := rl.limiters[key]; exists {
|
||||
entry.lastAccess = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
|
||||
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
|
||||
rl.limiters[key] = &limiterEntry{
|
||||
limiter: limiter,
|
||||
lastAccess: time.Now(),
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes old limiters that haven't been used recently
|
||||
func (rl *APIRateLimiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(rl.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rl.cleanup()
|
||||
case <-rl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes limiters that haven't been used within the TTL period
|
||||
func (rl *APIRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, entry := range rl.limiters {
|
||||
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (rl *APIRateLimiter) Stop() {
|
||||
close(rl.stopChan)
|
||||
}
|
||||
|
||||
// Reset removes the rate limiter for a specific key
|
||||
func (rl *APIRateLimiter) Reset(key string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
@@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
httpStatus = http.StatusUnauthorized
|
||||
case status.BadRequest:
|
||||
httpStatus = http.StatusBadRequest
|
||||
case status.TooManyRequests:
|
||||
httpStatus = http.StatusTooManyRequests
|
||||
default:
|
||||
}
|
||||
msg = strings.ToLower(err.Error())
|
||||
|
||||
@@ -37,6 +37,9 @@ const (
|
||||
|
||||
// Unauthenticated indicates that user is not authenticated due to absence of valid credentials
|
||||
Unauthenticated Type = 10
|
||||
|
||||
// TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting)
|
||||
TooManyRequests Type = 11
|
||||
)
|
||||
|
||||
// Type is a type of the Error
|
||||
|
||||
Reference in New Issue
Block a user