[management] add pat rate limiting (#4741)

This commit is contained in:
Pascal Fischer
2025-11-07 15:50:18 +01:00
committed by GitHub
parent 6aa4ba7af4
commit 48475ddc05
7 changed files with 496 additions and 4 deletions

2
go.mod
View File

@@ -108,6 +108,7 @@ require (
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.16.0
golang.org/x/term v0.33.0
golang.org/x/time v0.12.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@@ -245,7 +246,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.34.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect

View File

@@ -4,9 +4,13 @@ import (
"context"
"fmt"
"net/http"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
@@ -38,7 +42,12 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const apiPrefix = "/api"
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
@@ -58,11 +67,42 @@ func NewAPIHandler(
settingsManager settings.Manager,
) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
)
corsMiddleware := cors.AllowAll()

View File

@@ -29,6 +29,7 @@ type AuthMiddleware struct {
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
}
// NewAuthMiddleware instance constructor
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
}
}
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
request, err := m.checkPATFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
err = status.Errorf(status.Unauthorized, "token invalid")
}
util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, request)
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {

View File

@@ -27,7 +27,9 @@ const (
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
tokenID2 = "tokenID2"
PAT = "nbp_PAT"
PAT2 = "nbp_PAT2"
JWT = "JWT"
wrongToken = "wrongToken"
)
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
tokenID2: {
ID: tokenID2,
Name: "My second token",
HashedToken: "someHash2",
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
CreatedBy: userID,
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
},
},
},
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
if token == PAT {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
if token == PAT2 {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
}
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
if token == tokenID || token == tokenID2 {
return nil
}
return fmt.Errorf("Should never get reached")
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
}
func TestAuthMiddleware_RateLimiting(t *testing.T) {
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
// Configure rate limiter: 10 requests per minute with burst of 5
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 5,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make burst requests - all should succeed
successCount := 0
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 5, successCount, "All burst requests should succeed")
// The 6th request should fail (exceeded burst)
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
})
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
// Configure very low rate limit: 1 request per minute
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request should fail (rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
// Configure strict rate limit
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make multiple requests with Bearer token - all should succeed
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
})
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
// Configure rate limiter
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Use first PAT token
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
// Second request with same token should fail
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
// Use second PAT token - should succeed because it has independent rate limit
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
// Second request with PAT2 should also be rate limited
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
// JWT should still work (not rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
})
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
// Configure rate limiter with short cleanup interval and TTL for testing
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: 100 * time.Millisecond,
LimiterTTL: 200 * time.Millisecond,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request - should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request immediately - should fail (burst exhausted)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
time.Sleep(400 * time.Millisecond)
// After cleanup, the limiter should be removed and recreated with full burst capacity
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
// Verify it's a fresh limiter by checking burst is reset
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
for _, tc := range tt {

View File

@@ -0,0 +1,146 @@
package middleware
import (
"context"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
RequestsPerMinute float64
// Burst defines the maximum number of requests that can be made in a burst
Burst int
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
CleanupInterval time.Duration
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
LimiterTTL time.Duration
}
// DefaultRateLimiterConfig returns a default configuration
func DefaultRateLimiterConfig() *RateLimiterConfig {
return &RateLimiterConfig{
RequestsPerMinute: 100,
Burst: 120,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// APIRateLimiter manages rate limiting for API tokens
type APIRateLimiter struct {
config *RateLimiterConfig
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
if config == nil {
config = DefaultRateLimiterConfig()
}
rl := &APIRateLimiter{
config: config,
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
go rl.cleanupLoop()
return rl
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
limiter := rl.getLimiter(key)
return limiter.Allow()
}
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
// getLimiter retrieves or creates a rate limiter for the given key
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
entry, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
rl.mu.Lock()
entry.lastAccess = time.Now()
rl.mu.Unlock()
return entry.limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
if entry, exists := rl.limiters[key]; exists {
entry.lastAccess = time.Now()
return entry.limiter
}
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
rl.limiters[key] = &limiterEntry{
limiter: limiter,
lastAccess: time.Now(),
}
return limiter
}
// cleanupLoop periodically removes old limiters that haven't been used recently
func (rl *APIRateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
rl.cleanup()
case <-rl.stopChan:
return
}
}
}
// cleanup removes limiters that haven't been used within the TTL period
func (rl *APIRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, entry := range rl.limiters {
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
delete(rl.limiters, key)
}
}
}
// Stop stops the cleanup goroutine
func (rl *APIRateLimiter) Stop() {
close(rl.stopChan)
}
// Reset removes the rate limiter for a specific key
func (rl *APIRateLimiter) Reset(key string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.limiters, key)
}

View File

@@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
httpStatus = http.StatusUnauthorized
case status.BadRequest:
httpStatus = http.StatusBadRequest
case status.TooManyRequests:
httpStatus = http.StatusTooManyRequests
default:
}
msg = strings.ToLower(err.Error())

View File

@@ -37,6 +37,9 @@ const (
// Unauthenticated indicates that user is not authenticated due to absence of valid credentials
Unauthenticated Type = 10
// TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting)
TooManyRequests Type = 11
)
// Type is a type of the Error