mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management] adapt ratelimiting (#5080)
This commit is contained in:
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
m.patUsageTracker.IncrementUsage(token)
|
m.patUsageTracker.IncrementUsage(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.rateLimiter != nil {
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||||
if !m.rateLimiter.Allow(token) {
|
if !m.rateLimiter.Allow(token) {
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
}
|
}
|
||||||
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTerraformRequest(r *http.Request) bool {
|
||||||
|
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||||
|
return strings.Contains(ua, "terraform")
|
||||||
|
}
|
||||||
|
|
||||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||||
// the JWT token from the Authorization header.
|
// the JWT token from the Authorization header.
|
||||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||||
|
|||||||
@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
handler.ServeHTTP(rec, req)
|
handler.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||||
|
rateLimitConfig := &RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 1,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
authMiddleware := NewAuthMiddleware(
|
||||||
|
mockAuth,
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
|
return &types.User{}, nil
|
||||||
|
},
|
||||||
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Test various Terraform user agent formats
|
||||||
|
terraformUserAgents := []string{
|
||||||
|
"Terraform/1.5.0",
|
||||||
|
"terraform/1.0.0",
|
||||||
|
"Terraform-Provider/2.0.0",
|
||||||
|
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userAgent := range terraformUserAgents {
|
||||||
|
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||||
|
successCount := 0
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code == http.StatusOK {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||||
|
rateLimitConfig := &RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 1,
|
||||||
|
Burst: 1,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
authMiddleware := NewAuthMiddleware(
|
||||||
|
mockAuth,
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||||
|
return userAuth.AccountId, userAuth.UserId, nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||||
|
return &types.User{}, nil
|
||||||
|
},
|
||||||
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||||
|
|
||||||
|
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+PAT)
|
||||||
|
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user