diff --git a/management/server/http/handlers/proxy/auth.go b/management/server/http/handlers/proxy/auth.go index 27c64ecae..690b9e703 100644 --- a/management/server/http/handlers/proxy/auth.go +++ b/management/server/http/handlers/proxy/auth.go @@ -2,8 +2,11 @@ package proxy import ( "context" + "net" "net/http" "net/url" + "strings" + "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" @@ -11,6 +14,7 @@ import ( "golang.org/x/oauth2" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/proxy/auth" ) @@ -18,12 +22,21 @@ import ( // AuthCallbackHandler handles OAuth callbacks for proxy authentication. type AuthCallbackHandler struct { proxyService *nbgrpc.ProxyServiceServer + rateLimiter *middleware.APIRateLimiter } // NewAuthCallbackHandler creates a new OAuth callback handler. func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler { + rateLimiterConfig := &middleware.RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 15, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + return &AuthCallbackHandler{ proxyService: proxyService, + rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig), } } @@ -33,6 +46,13 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) { } func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) { + clientIP := getClientIP(r) + if !h.rateLimiter.Allow(clientIP) { + log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded") + http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) + return + } + state := r.URL.Query().Get("state") codeVerifier, originalURL, err := h.proxyService.ValidateState(state) @@ -128,3 +148,24 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config return claims.Subject } + +// getClientIP extracts the client IP address from the request. +func getClientIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return xff + } + + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + + // Fall back to RemoteAddr + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/management/server/http/handlers/proxy/auth_test.go b/management/server/http/handlers/proxy/auth_test.go new file mode 100644 index 000000000..d462d5689 --- /dev/null +++ b/management/server/http/handlers/proxy/auth_test.go @@ -0,0 +1,156 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" +) + +func TestAuthCallbackHandler_RateLimiting(t *testing.T) { + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized") + + req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil) + req.RemoteAddr = "192.168.1.100:12345" + + t.Run("allows requests under limit", func(t *testing.T) { + for i := 0; i < 15; i++ { + allowed := handler.rateLimiter.Allow("192.168.1.100") + assert.True(t, allowed, "Request %d should be allowed", i+1) + } + }) + + t.Run("blocks requests over limit", func(t *testing.T) { + handler.rateLimiter.Reset("192.168.1.200") + + for i := 0; i < 15; i++ { + handler.rateLimiter.Allow("192.168.1.200") + } + + allowed := handler.rateLimiter.Allow("192.168.1.200") + assert.False(t, allowed, "Request over limit should be blocked") + }) + + t.Run("different IPs have separate limits", func(t *testing.T) { + ip1 := "192.168.1.201" + ip2 := "192.168.1.202" + + handler.rateLimiter.Reset(ip1) + handler.rateLimiter.Reset(ip2) + + for i := 0; i < 15; i++ { + handler.rateLimiter.Allow(ip1) + } + + assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked") + + assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed") + }) +} + +func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) { + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + testIP := "10.0.0.50" + + handler.rateLimiter.Reset(testIP) + + t.Run("returns 429 when rate limited", func(t *testing.T) { + for i := 0; i < 15; i++ { + handler.rateLimiter.Allow(testIP) + } + + req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil) + req.RemoteAddr = testIP + ":12345" + + rr := httptest.NewRecorder() + handler.handleCallback(rr, req) + + assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code") + assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message") + }) +} + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + xForwardedFor string + xRealIP string + expectedIP string + }{ + { + name: "extract from RemoteAddr", + remoteAddr: "192.168.1.100:12345", + expectedIP: "192.168.1.100", + }, + { + name: "extract from X-Forwarded-For single IP", + remoteAddr: "10.0.0.1:54321", + xForwardedFor: "203.0.113.195", + expectedIP: "203.0.113.195", + }, + { + name: "extract from X-Forwarded-For multiple IPs", + remoteAddr: "10.0.0.1:54321", + xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178", + expectedIP: "203.0.113.195", + }, + { + name: "extract from X-Real-IP", + remoteAddr: "10.0.0.1:54321", + xRealIP: "198.51.100.42", + expectedIP: "198.51.100.42", + }, + { + name: "X-Forwarded-For takes precedence over X-Real-IP", + remoteAddr: "10.0.0.1:54321", + xForwardedFor: "203.0.113.195", + xRealIP: "198.51.100.42", + expectedIP: "203.0.113.195", + }, + { + name: "handle RemoteAddr without port", + remoteAddr: "192.168.1.100", + expectedIP: "192.168.1.100", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = tt.remoteAddr + + if tt.xForwardedFor != "" { + req.Header.Set("X-Forwarded-For", tt.xForwardedFor) + } + if tt.xRealIP != "" { + req.Header.Set("X-Real-IP", tt.xRealIP) + } + + ip := getClientIP(req) + assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected") + }) + } +} + +func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) { + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + + require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized") + + testIP := "192.168.1.250" + handler.rateLimiter.Reset(testIP) + + for i := 0; i < 15; i++ { + allowed := handler.rateLimiter.Allow(testIP) + assert.True(t, allowed, "Should allow request %d within burst limit", i+1) + } + + allowed := handler.rateLimiter.Allow(testIP) + assert.False(t, allowed, "Should block request that exceeds burst limit") +}