Validate trusted proxies in OAuth callback getClientIP

This commit is contained in:
Viktor Liu
2026-02-12 21:15:41 +08:00
parent 7fdb824a37
commit 9554934b92
6 changed files with 126 additions and 55 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
@@ -13,15 +14,18 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -33,6 +37,8 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -46,7 +52,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
@@ -68,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -174,7 +179,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer)
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
oauthHandler.RegisterEndpoints(router)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"time"
@@ -21,12 +22,13 @@ import (
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
trustedProxies []netip.Prefix
}
// NewAuthCallbackHandler creates a new OAuth callback handler.
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
rateLimiterConfig := &middleware.RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 15,
@@ -35,8 +37,9 @@ func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallba
}
return &AuthCallbackHandler{
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
trustedProxies: trustedProxies,
}
}
@@ -46,7 +49,7 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
}
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
clientIP := h.resolveClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
@@ -149,23 +152,57 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
return claims.Subject
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
// resolveClientIP extracts the real client IP from the request.
// When trustedProxies is non-empty and the direct peer is trusted,
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
// Otherwise it returns RemoteAddr directly.
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
remoteIP := extractHost(r.RemoteAddr)
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
return remoteIP
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !isTrustedProxy(ip, h.trustedProxies) {
return ip
}
return xff
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}
// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
func extractHost(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return r.RemoteAddr
return remoteAddr
}
return host
}
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -200,7 +200,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
proxyService.SetProxyManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService)
handler := NewAuthCallbackHandler(proxyService, nil)
router := mux.NewRouter()
handler.RegisterEndpoints(router)

View File

@@ -3,6 +3,7 @@ package proxy
import (
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,7 +13,7 @@ import (
)
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
@@ -54,7 +55,7 @@ func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
}
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
testIP := "10.0.0.50"
handler.rateLimiter.Reset(testIP)
@@ -75,46 +76,76 @@ func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
})
}
func TestGetClientIP(t *testing.T) {
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
trustedProxy []netip.Prefix
expectedIP string
}{
{
name: "extract from RemoteAddr",
remoteAddr: "192.168.1.100:12345",
expectedIP: "192.168.1.100",
name: "no trusted proxies returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4",
trustedProxy: nil,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Forwarded-For single IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
expectedIP: "203.0.113.195",
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xForwardedFor: "1.2.3.4, 10.0.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Forwarded-For multiple IPs",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178",
expectedIP: "203.0.113.195",
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "extract from X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xRealIP: "198.51.100.42",
expectedIP: "198.51.100.42",
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "10.0.0.1:54321",
xForwardedFor: "203.0.113.195",
xRealIP: "198.51.100.42",
expectedIP: "203.0.113.195",
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
trustedProxy: trusted,
expectedIP: "10.0.0.1",
},
{
name: "handle RemoteAddr without port",
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trustedProxy: trusted,
expectedIP: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: " 203.0.113.50 , 10.0.0.2 ",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xForwardedFor: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trustedProxy: trusted,
expectedIP: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "192.168.1.100",
expectedIP: "192.168.1.100",
},
@@ -122,24 +153,22 @@ func TestGetClientIP(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, tt.trustedProxy)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected")
ip := handler.resolveClientIP(req)
assert.Equal(t, tt.expectedIP, ip)
})
}
}
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{}, nil)
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")

View File

@@ -114,7 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}