diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 966a6802a..257347153 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] m.patUsageTracker.IncrementUsage(token) } - if m.rateLimiter != nil { + if m.rateLimiter != nil && !isTerraformRequest(r) { if !m.rateLimiter.Allow(token) { return r, status.Errorf(status.TooManyRequests, "too many requests") } @@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] return nbcontext.SetUserAuthInRequest(r, userAuth), nil } +func isTerraformRequest(r *http.Request) bool { + ua := strings.ToLower(r.Header.Get("User-Agent")) + return strings.Contains(ua, "terraform") +} + // getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // the JWT token from the Authorization header. func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index ba4d16796..05ca59419 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") }) + + t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) { + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Test various Terraform user agent formats + terraformUserAgents := []string{ + "Terraform/1.5.0", + "terraform/1.0.0", + "Terraform-Provider/2.0.0", + "Mozilla/5.0 (compatible; Terraform/1.3.0)", + } + + for _, userAgent := range terraformUserAgents { + t.Run("UserAgent: "+userAgent, func(t *testing.T) { + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + req.Header.Set("User-Agent", userAgent) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)") + }) + } + }) + + t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) { + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + nil, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + req.Header.Set("User-Agent", "curl/7.68.0") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + req.Header.Set("User-Agent", "curl/7.68.0") + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) } func TestAuthMiddleware_Handler_Child(t *testing.T) {