diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 36e28504d..7da1e6898 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) if err != nil { log.Fatalf("failed to create API handler: %v", err) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 6e8b3cc55..9d2384cae 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "net/netip" "os" "strconv" "time" @@ -13,15 +14,18 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -33,6 +37,8 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/http/handlers/proxy" + nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" @@ -46,7 +52,6 @@ import ( "github.com/netbirdio/netbird/management/server/http/handlers/networks" "github.com/netbirdio/netbird/management/server/http/handlers/peers" "github.com/netbirdio/netbird/management/server/http/handlers/policies" - "github.com/netbirdio/netbird/management/server/http/handlers/proxy" "github.com/netbirdio/netbird/management/server/http/handlers/routes" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/handlers/users" @@ -68,7 +73,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -174,7 +179,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { - oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer) + oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) oauthHandler.RegisterEndpoints(router) } diff --git a/management/server/http/handlers/proxy/auth.go b/management/server/http/handlers/proxy/auth.go index 690b9e703..0120fad0e 100644 --- a/management/server/http/handlers/proxy/auth.go +++ b/management/server/http/handlers/proxy/auth.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "net/netip" "net/url" "strings" "time" @@ -21,12 +22,13 @@ import ( // AuthCallbackHandler handles OAuth callbacks for proxy authentication. type AuthCallbackHandler struct { - proxyService *nbgrpc.ProxyServiceServer - rateLimiter *middleware.APIRateLimiter + proxyService *nbgrpc.ProxyServiceServer + rateLimiter *middleware.APIRateLimiter + trustedProxies []netip.Prefix } // NewAuthCallbackHandler creates a new OAuth callback handler. -func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler { +func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler { rateLimiterConfig := &middleware.RateLimiterConfig{ RequestsPerMinute: 10, Burst: 15, @@ -35,8 +37,9 @@ func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallba } return &AuthCallbackHandler{ - proxyService: proxyService, - rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig), + proxyService: proxyService, + rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig), + trustedProxies: trustedProxies, } } @@ -46,7 +49,7 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) { } func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) { - clientIP := getClientIP(r) + clientIP := h.resolveClientIP(r) if !h.rateLimiter.Allow(clientIP) { log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded") http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) @@ -149,23 +152,57 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config return claims.Subject } -// getClientIP extracts the client IP address from the request. -func getClientIP(r *http.Request) string { - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - if idx := strings.Index(xff, ","); idx != -1 { - return strings.TrimSpace(xff[:idx]) +// resolveClientIP extracts the real client IP from the request. +// When trustedProxies is non-empty and the direct peer is trusted, +// it walks X-Forwarded-For right-to-left skipping trusted IPs. +// Otherwise it returns RemoteAddr directly. +func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string { + remoteIP := extractHost(r.RemoteAddr) + + if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) { + return remoteIP + } + + xff := r.Header.Get("X-Forwarded-For") + if xff == "" { + return remoteIP + } + + parts := strings.Split(xff, ",") + for i := len(parts) - 1; i >= 0; i-- { + ip := strings.TrimSpace(parts[i]) + if ip == "" { + continue + } + if !isTrustedProxy(ip, h.trustedProxies) { + return ip } - return xff } - if xri := r.Header.Get("X-Real-IP"); xri != "" { - return xri + // All IPs in XFF are trusted; return the leftmost as best guess. + if first := strings.TrimSpace(parts[0]); first != "" { + return first } + return remoteIP +} - // Fall back to RemoteAddr - host, _, err := net.SplitHostPort(r.RemoteAddr) +func extractHost(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) if err != nil { - return r.RemoteAddr + return remoteAddr } return host } + +func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool { + addr, err := netip.ParseAddr(ipStr) + if err != nil { + return false + } + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index ee3a6a93d..0a9a560cd 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -200,7 +200,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { proxyService.SetProxyManager(&testServiceManager{store: testStore}) - handler := NewAuthCallbackHandler(proxyService) + handler := NewAuthCallbackHandler(proxyService, nil) router := mux.NewRouter() handler.RegisterEndpoints(router) diff --git a/management/server/http/handlers/proxy/auth_test.go b/management/server/http/handlers/proxy/auth_test.go index d462d5689..360405474 100644 --- a/management/server/http/handlers/proxy/auth_test.go +++ b/management/server/http/handlers/proxy/auth_test.go @@ -3,6 +3,7 @@ package proxy import ( "net/http" "net/http/httptest" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -12,7 +13,7 @@ import ( ) func TestAuthCallbackHandler_RateLimiting(t *testing.T) { - handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil) require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized") req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil) @@ -54,7 +55,7 @@ func TestAuthCallbackHandler_RateLimiting(t *testing.T) { } func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) { - handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil) testIP := "10.0.0.50" handler.rateLimiter.Reset(testIP) @@ -75,46 +76,76 @@ func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) { }) } -func TestGetClientIP(t *testing.T) { +func TestResolveClientIP(t *testing.T) { + trusted := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + } + tests := []struct { name string remoteAddr string xForwardedFor string - xRealIP string + trustedProxy []netip.Prefix expectedIP string }{ { - name: "extract from RemoteAddr", - remoteAddr: "192.168.1.100:12345", - expectedIP: "192.168.1.100", + name: "no trusted proxies returns RemoteAddr", + remoteAddr: "203.0.113.50:9999", + xForwardedFor: "1.2.3.4", + trustedProxy: nil, + expectedIP: "203.0.113.50", }, { - name: "extract from X-Forwarded-For single IP", - remoteAddr: "10.0.0.1:54321", - xForwardedFor: "203.0.113.195", - expectedIP: "203.0.113.195", + name: "untrusted RemoteAddr ignores XFF", + remoteAddr: "203.0.113.50:9999", + xForwardedFor: "1.2.3.4, 10.0.0.1", + trustedProxy: trusted, + expectedIP: "203.0.113.50", }, { - name: "extract from X-Forwarded-For multiple IPs", - remoteAddr: "10.0.0.1:54321", - xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178", - expectedIP: "203.0.113.195", + name: "trusted RemoteAddr with single client in XFF", + remoteAddr: "10.0.0.1:5000", + xForwardedFor: "203.0.113.50", + trustedProxy: trusted, + expectedIP: "203.0.113.50", }, { - name: "extract from X-Real-IP", - remoteAddr: "10.0.0.1:54321", - xRealIP: "198.51.100.42", - expectedIP: "198.51.100.42", + name: "trusted RemoteAddr walks past trusted entries in XFF", + remoteAddr: "10.0.0.1:5000", + xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5", + trustedProxy: trusted, + expectedIP: "203.0.113.50", }, { - name: "X-Forwarded-For takes precedence over X-Real-IP", - remoteAddr: "10.0.0.1:54321", - xForwardedFor: "203.0.113.195", - xRealIP: "198.51.100.42", - expectedIP: "203.0.113.195", + name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", + remoteAddr: "10.0.0.1:5000", + trustedProxy: trusted, + expectedIP: "10.0.0.1", }, { - name: "handle RemoteAddr without port", + name: "all XFF IPs trusted returns leftmost", + remoteAddr: "10.0.0.1:5000", + xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3", + trustedProxy: trusted, + expectedIP: "10.0.0.2", + }, + { + name: "XFF with whitespace", + remoteAddr: "10.0.0.1:5000", + xForwardedFor: " 203.0.113.50 , 10.0.0.2 ", + trustedProxy: trusted, + expectedIP: "203.0.113.50", + }, + { + name: "multi-hop with mixed trust", + remoteAddr: "10.0.0.1:5000", + xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1", + trustedProxy: trusted, + expectedIP: "203.0.113.50", + }, + { + name: "RemoteAddr without port", remoteAddr: "192.168.1.100", expectedIP: "192.168.1.100", }, @@ -122,24 +153,22 @@ func TestGetClientIP(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy) + req := httptest.NewRequest(http.MethodGet, "/test", nil) req.RemoteAddr = tt.remoteAddr - if tt.xForwardedFor != "" { req.Header.Set("X-Forwarded-For", tt.xForwardedFor) } - if tt.xRealIP != "" { - req.Header.Set("X-Real-IP", tt.xRealIP) - } - ip := getClientIP(req) - assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected") + ip := handler.resolveClientIP(req) + assert.Equal(t, tt.expectedIP, ip) }) } } func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) { - handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}) + handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil) require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized") diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index ad9d56d5e..c5e21b591 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -114,7 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) }