diff --git a/go.mod b/go.mod index 68a12908d..7b9bae321 100644 --- a/go.mod +++ b/go.mod @@ -108,6 +108,7 @@ require ( golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 + golang.org/x/time v0.12.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -245,7 +246,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.12.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3d4de31d0..4d2c224b4 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,9 +4,13 @@ import ( "context" "fmt" "net/http" + "os" + "strconv" + "time" "github.com/gorilla/mux" "github.com/rs/cors" + log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" @@ -38,7 +42,12 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" +const ( + apiPrefix = "/api" + rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" + rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" + rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" +) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. func NewAPIHandler( @@ -58,11 +67,42 @@ func NewAPIHandler( settingsManager settings.Manager, ) (http.Handler, error) { + var rateLimitingConfig *middleware.RateLimiterConfig + if os.Getenv(rateLimitingEnabledKey) == "true" { + rpm := 6 + if v := os.Getenv(rateLimitingRPMKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) + } else { + rpm = value + } + } + + burst := 500 + if v := os.Getenv(rateLimitingBurstKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) + } else { + burst = value + } + } + + rateLimitingConfig = &middleware.RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + } + } + authMiddleware := middleware.NewAuthMiddleware( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, + rateLimitingConfig, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..bce917a25 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -29,6 +29,7 @@ type AuthMiddleware struct { ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc + rateLimiter *APIRateLimiter } // NewAuthMiddleware instance constructor @@ -37,12 +38,19 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, + rateLimiterConfig *RateLimiterConfig, ) *AuthMiddleware { + var rateLimiter *APIRateLimiter + if rateLimiterConfig != nil { + rateLimiter = NewAPIRateLimiter(rateLimiterConfig) + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, + rateLimiter: rateLimiter, } } @@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { request, err := m.checkPATFromRequest(r, auth) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) + // Check if it's a status error, otherwise default to Unauthorized + if _, ok := status.FromError(err); !ok { + err = status.Errorf(status.Unauthorized, "token invalid") + } + util.WriteError(r.Context(), err, w) return } h.ServeHTTP(w, request) @@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, fmt.Errorf("error extracting token: %w", err) } + if m.rateLimiter != nil { + if !m.rateLimiter.Allow(token) { + return r, status.Errorf(status.TooManyRequests, "too many requests") + } + } + ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..d1bd9959f 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -27,7 +27,9 @@ const ( domainCategory = "domainCategory" userID = "userID" tokenID = "tokenID" + tokenID2 = "tokenID2" PAT = "nbp_PAT" + PAT2 = "nbp_PAT2" JWT = "JWT" wrongToken = "wrongToken" ) @@ -49,6 +51,15 @@ var testAccount = &types.Account{ CreatedAt: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()), }, + tokenID2: { + ID: tokenID2, + Name: "My second token", + HashedToken: "someHash2", + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), + CreatedBy: userID, + CreatedAt: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), + }, }, }, }, @@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } + if token == PAT2 { + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil + } return nil, nil, "", "", fmt.Errorf("PAT invalid") } @@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA } func mockMarkPATUsed(_ context.Context, token string) error { - if token == tokenID { + if token == tokenID || token == tokenID2 { return nil } return fmt.Errorf("Should never get reached") @@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +func TestAuthMiddleware_RateLimiting(t *testing.T) { + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) { + // Configure rate limiter: 10 requests per minute with burst of 5 + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 5, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make burst requests - all should succeed + successCount := 0 + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 5, successCount, "All burst requests should succeed") + + // The 6th request should fail (exceeded burst) + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited") + }) + + t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) { + // Configure very low rate limit: 1 request per minute + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request should fail (rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) + + t.Run("Bearer Token Not Rate Limited", func(t *testing.T) { + // Configure strict rate limit + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make multiple requests with Bearer token - all should succeed + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)") + }) + + t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) { + // Configure rate limiter + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use first PAT token + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed") + + // Second request with same token should fail + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited") + + // Use second PAT token - should succeed because it has independent rate limit + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)") + + // Second request with PAT2 should also be rate limited + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited") + + // JWT should still work (not rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)") + }) + + t.Run("Rate Limiter Cleanup", func(t *testing.T) { + // Configure rate limiter with short cleanup interval and TTL for testing + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: 100 * time.Millisecond, + LimiterTTL: 200 * time.Millisecond, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request - should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request immediately - should fail (burst exhausted) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + + // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer) + time.Sleep(400 * time.Millisecond) + + // After cleanup, the limiter should be removed and recreated with full burst capacity + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)") + + // Verify it's a fresh limiter by checking burst is reset + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") + }) +} + func TestAuthMiddleware_Handler_Child(t *testing.T) { tt := []struct { name string @@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go new file mode 100644 index 000000000..a6266d4f3 --- /dev/null +++ b/management/server/http/middleware/rate_limiter.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterConfig holds configuration for the API rate limiter +type RateLimiterConfig struct { + // RequestsPerMinute defines the rate at which tokens are replenished + RequestsPerMinute float64 + // Burst defines the maximum number of requests that can be made in a burst + Burst int + // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs) + CleanupInterval time.Duration + // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal) + LimiterTTL time.Duration +} + +// DefaultRateLimiterConfig returns a default configuration +func DefaultRateLimiterConfig() *RateLimiterConfig { + return &RateLimiterConfig{ + RequestsPerMinute: 100, + Burst: 120, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } +} + +// limiterEntry holds a rate limiter and its last access time +type limiterEntry struct { + limiter *rate.Limiter + lastAccess time.Time +} + +// APIRateLimiter manages rate limiting for API tokens +type APIRateLimiter struct { + config *RateLimiterConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + stopChan chan struct{} +} + +// NewAPIRateLimiter creates a new API rate limiter with the given configuration +func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { + if config == nil { + config = DefaultRateLimiterConfig() + } + + rl := &APIRateLimiter{ + config: config, + limiters: make(map[string]*limiterEntry), + stopChan: make(chan struct{}), + } + + go rl.cleanupLoop() + + return rl +} + +// Allow checks if a request for the given key (token) is allowed +func (rl *APIRateLimiter) Allow(key string) bool { + limiter := rl.getLimiter(key) + return limiter.Allow() +} + +// Wait blocks until the rate limiter allows another request for the given key +// Returns an error if the context is canceled +func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + limiter := rl.getLimiter(key) + return limiter.Wait(ctx) +} + +// getLimiter retrieves or creates a rate limiter for the given key +func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.RLock() + entry, exists := rl.limiters[key] + rl.mu.RUnlock() + + if exists { + rl.mu.Lock() + entry.lastAccess = time.Now() + rl.mu.Unlock() + return entry.limiter + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + if entry, exists := rl.limiters[key]; exists { + entry.lastAccess = time.Now() + return entry.limiter + } + + requestsPerSecond := rl.config.RequestsPerMinute / 60.0 + limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst) + rl.limiters[key] = &limiterEntry{ + limiter: limiter, + lastAccess: time.Now(), + } + + return limiter +} + +// cleanupLoop periodically removes old limiters that haven't been used recently +func (rl *APIRateLimiter) cleanupLoop() { + ticker := time.NewTicker(rl.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + rl.cleanup() + case <-rl.stopChan: + return + } + } +} + +// cleanup removes limiters that haven't been used within the TTL period +func (rl *APIRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, entry := range rl.limiters { + if now.Sub(entry.lastAccess) > rl.config.LimiterTTL { + delete(rl.limiters, key) + } + } +} + +// Stop stops the cleanup goroutine +func (rl *APIRateLimiter) Stop() { + close(rl.stopChan) +} + +// Reset removes the rate limiter for a specific key +func (rl *APIRateLimiter) Reset(key string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.limiters, key) +} diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go index 3ae321023..0a29469da 100644 --- a/shared/management/http/util/util.go +++ b/shared/management/http/util/util.go @@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized case status.BadRequest: httpStatus = http.StatusBadRequest + case status.TooManyRequests: + httpStatus = http.StatusTooManyRequests default: } msg = strings.ToLower(err.Error()) diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..09676847e 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -37,6 +37,9 @@ const ( // Unauthenticated indicates that user is not authenticated due to absence of valid credentials Unauthenticated Type = 10 + + // TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting) + TooManyRequests Type = 11 ) // Type is a type of the Error