Set forwarded headers from trusted proxies only

This commit is contained in:
Viktor Liu
2026-02-08 17:31:10 +08:00
parent 5190923c70
commit ed58659a01
12 changed files with 608 additions and 57 deletions

View File

@@ -44,6 +44,7 @@ var (
oidcEndpoint string oidcEndpoint string
oidcScopes string oidcScopes string
forwardedProto string forwardedProto string
trustedProxies string
) )
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
@@ -72,6 +73,7 @@ func init() {
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication") rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated") rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https") rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
} }
// Execute runs the root command. // Execute runs the root command.
@@ -113,6 +115,11 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto) return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto)
} }
parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies)
if err != nil {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.Server{ srv := proxy.Server{
Logger: logger, Logger: logger,
Version: Version, Version: Version,
@@ -131,6 +138,7 @@ func runServer(cmd *cobra.Command, args []string) error {
OIDCEndpoint: oidcEndpoint, OIDCEndpoint: oidcEndpoint,
OIDCScopes: strings.Split(oidcScopes, ","), OIDCScopes: strings.Split(oidcScopes, ","),
ForwardedProto: forwardedProto, ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
} }
if err := srv.ListenAndServe(context.TODO(), addr); err != nil { if err := srv.ListenAndServe(context.TODO(), addr); err != nil {

View File

@@ -2,6 +2,7 @@ package accesslog
import ( import (
"context" "context"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -14,18 +15,24 @@ type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
} }
// Logger sends access log entries to the management server via gRPC.
type Logger struct { type Logger struct {
client gRPCClient client gRPCClient
logger *log.Logger logger *log.Logger
trustedProxies []netip.Prefix
} }
func NewLogger(client gRPCClient, logger *log.Logger) *Logger { // NewLogger creates a new access log Logger. The trustedProxies parameter
// configures which upstream proxy IP ranges are trusted for extracting
// the real client IP from X-Forwarded-For headers.
func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Prefix) *Logger {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
return &Logger{ return &Logger{
client: client, client: client,
logger: logger, logger: logger,
trustedProxies: trustedProxies,
} }
} }

View File

@@ -24,14 +24,15 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
status: http.StatusOK, status: http.StatusOK,
} }
// Get the source IP before passing the request on as the proxy will modify // Resolve the source IP using trusted proxy configuration before passing
// headers that we wish to use to gather that information on the request. // the request on, as the proxy will modify forwarding headers.
sourceIp := extractSourceIP(r) sourceIp := extractSourceIP(r, l.trustedProxies)
// Create a mutable struct to capture data from downstream handlers. // Create a mutable struct to capture data from downstream handlers.
// We pass a pointer in the context - the pointer itself flows down immutably, // We pass a pointer in the context - the pointer itself flows down immutably,
// but the struct it points to can be mutated by inner handlers. // but the struct it points to can be mutated by inner handlers.
capturedData := &proxy.CapturedData{RequestID: requestID} capturedData := &proxy.CapturedData{RequestID: requestID}
capturedData.SetClientIP(sourceIp)
ctx := proxy.WithCapturedData(r.Context(), capturedData) ctx := proxy.WithCapturedData(r.Context(), capturedData)
start := time.Now() start := time.Now()

View File

@@ -1,43 +1,16 @@
package accesslog package accesslog
import ( import (
"net"
"net/http" "net/http"
"slices" "net/netip"
"strings"
"github.com/netbirdio/netbird/proxy/internal/proxy"
) )
// requestIP attempts to extract the source IP from a request. // extractSourceIP resolves the real client IP from the request using trusted
// Adapted from https://husobee.github.io/golang/ip-address/2015/12/17/remote-ip-go.html // proxy configuration. When trustedProxies is non-empty and the direct
// with the addition of some newer stdlib functions that are now // connection is from a trusted source, it walks X-Forwarded-For right-to-left
// available. // skipping trusted IPs. Otherwise it returns RemoteAddr directly.
// The concept here is to look backwards through IP headers until func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string {
// the first public IP address is found. The hypothesis is that return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies)
// even if there are multiple IP addresses specified in these headers,
// the last public IP should be the hop immediately before reaching
// the server and therefore represents the "true" source IP regardless
// of the number of intermediate proxies or network hops.
func extractSourceIP(r *http.Request) string {
for _, h := range []string{"X-Forwarded-For", "X-Real-IP"} {
addresses := strings.Split(r.Header.Get(h), ",")
// Iterate from right to left until we get a public address
// that should be the address right before our proxy.
for _, address := range slices.Backward(addresses) {
// Trim the address because sometimes clients put whitespace in there.
ip := strings.TrimSpace(address)
// Parse the IP so that we can easily check whether it is a valid public address.
realIP := net.ParseIP(ip)
if !realIP.IsGlobalUnicast() || realIP.IsPrivate() || realIP.IsLoopback() {
continue
}
return ip
}
}
// Fallback to the requests RemoteAddr, this is least likely to be correct but
// should at least yield something in the event that the above has failed.
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
return ip
} }

View File

@@ -50,6 +50,7 @@ type CapturedData struct {
ServiceId string ServiceId string
AccountId types.AccountID AccountId types.AccountID
Origin ResponseOrigin Origin ResponseOrigin
ClientIP string
} }
// GetRequestID safely gets the request ID // GetRequestID safely gets the request ID
@@ -101,6 +102,20 @@ func (c *CapturedData) GetOrigin() ResponseOrigin {
return c.Origin return c.Origin
} }
// SetClientIP safely sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip string) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
}
// GetClientIP safely gets the resolved client IP.
func (c *CapturedData) GetClientIP() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
}
// WithCapturedData adds a CapturedData struct to the context // WithCapturedData adds a CapturedData struct to the context
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
return context.WithValue(ctx, capturedDataKey, data) return context.WithValue(ctx, capturedDataKey, data)

View File

@@ -6,6 +6,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/netip"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
@@ -22,6 +23,10 @@ type ReverseProxy struct {
// forwardedProto overrides the X-Forwarded-Proto header value. // forwardedProto overrides the X-Forwarded-Proto header value.
// Valid values: "auto" (detect from TLS), "http", "https". // Valid values: "auto" (detect from TLS), "http", "https".
forwardedProto string forwardedProto string
// trustedProxies is a list of IP prefixes for trusted upstream proxies.
// When the direct connection comes from a trusted proxy, forwarding
// headers are preserved and appended to instead of being stripped.
trustedProxies []netip.Prefix
mappingsMux sync.RWMutex mappingsMux sync.RWMutex
mappings map[string]Mapping mappings map[string]Mapping
logger *log.Logger logger *log.Logger
@@ -33,13 +38,14 @@ type ReverseProxy struct {
// between requested URLs and targets. // between requested URLs and targets.
// The internal mappings can be modified using the AddMapping // The internal mappings can be modified using the AddMapping
// and RemoveMapping functions. // and RemoveMapping functions.
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, logger *log.Logger) *ReverseProxy { func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy {
if logger == nil { if logger == nil {
logger = log.StandardLogger() logger = log.StandardLogger()
} }
return &ReverseProxy{ return &ReverseProxy{
transport: transport, transport: transport,
forwardedProto: forwardedProto, forwardedProto: forwardedProto,
trustedProxies: trustedProxies,
mappings: make(map[string]Mapping), mappings: make(map[string]Mapping),
logger: logger, logger: logger,
} }
@@ -96,21 +102,73 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, passHostHeader bool) func(r
} }
clientIP := extractClientIP(r.In.RemoteAddr) clientIP := extractClientIP(r.In.RemoteAddr)
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
// Strip any incoming forwarding headers since this proxy is the trust if IsTrustedProxy(clientIP, p.trustedProxies) {
// boundary and set them fresh based on the direct connection. p.setTrustedForwardingHeaders(r, clientIP)
r.Out.Header.Set("X-Forwarded-For", clientIP) } else {
r.Out.Header.Set("X-Real-IP", clientIP) p.setUntrustedForwardingHeaders(r, clientIP)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host) }
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
stripSessionCookie(r) stripSessionCookie(r)
stripSessionTokenQuery(r) stripSessionTokenQuery(r)
} }
} }
// setTrustedForwardingHeaders appends to the existing forwarding header chain
// and preserves upstream-provided headers when the direct connection is from
// a trusted proxy.
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
// Append the direct connection IP to the existing X-Forwarded-For chain.
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
} else {
r.Out.Header.Set("X-Forwarded-For", clientIP)
}
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" {
r.Out.Header.Set("X-Real-IP", realIP)
} else {
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
r.Out.Header.Set("X-Real-IP", resolved)
}
// Preserve upstream X-Forwarded-Host if present.
if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" {
r.Out.Header.Set("X-Forwarded-Host", fwdHost)
} else {
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
}
// Trust upstream X-Forwarded-Proto; fall back to local resolution.
if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
r.Out.Header.Set("X-Forwarded-Proto", fwdProto)
} else {
r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS))
}
// Trust upstream X-Forwarded-Port; fall back to local computation.
if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" {
r.Out.Header.Set("X-Forwarded-Port", fwdPort)
} else {
resolvedProto := r.Out.Header.Get("X-Forwarded-Proto")
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto))
}
}
// setUntrustedForwardingHeaders strips all incoming forwarding headers and
// sets them fresh based on the direct connection. This is the default
// behavior when no trusted proxies are configured or the direct connection
// is from an untrusted source.
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Real-IP", clientIP)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
}
// stripSessionCookie removes the proxy's session cookie from the outgoing // stripSessionCookie removes the proxy's session cookie from the outgoing
// request while preserving all other cookies. // request while preserving all other cookies.
func stripSessionCookie(r *httputil.ProxyRequest) { func stripSessionCookie(r *httputil.ProxyRequest) {
@@ -163,14 +221,23 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
cd.SetOrigin(OriginProxyError) cd.SetOrigin(OriginProxyError)
} }
requestID := getRequestID(r) requestID := getRequestID(r)
clientIP := getClientIP(r)
title, message, code, status := classifyProxyError(err) title, message, code, status := classifyProxyError(err)
log.Warnf("proxy error: request_id=%s method=%s host=%s path=%s status=%d title=%q err=%v", log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
requestID, r.Method, r.Host, r.URL.Path, code, title, err) requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
web.ServeErrorPage(w, r, code, title, message, requestID, status) web.ServeErrorPage(w, r, code, title, message, requestID, status)
} }
// getClientIP retrieves the resolved client IP from context.
func getClientIP(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
return capturedData.GetClientIP()
}
return ""
}
// getRequestID retrieves the request ID from context or returns empty string. // getRequestID retrieves the request ID from context or returns empty string.
func getRequestID(r *http.Request) string { func getRequestID(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/netip"
"net/url" "net/url"
"testing" "testing"
@@ -293,6 +294,158 @@ func TestExtractForwardedPort(t *testing.T) {
} }
} }
func TestRewriteFunc_TrustedProxy(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "203.0.113.50, 10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"))
})
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
pr.In.Header.Set("X-Real-IP", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"))
})
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"should resolve real client through trusted chain")
})
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
rewrite(pr)
assert.Equal(t, "original.example.com", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
rewrite(pr)
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"),
"should use configured forwardedProto as fallback")
})
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
pr.In.Header.Set("X-Real-IP", "evil")
pr.In.Header.Set("X-Forwarded-Host", "evil.example.com")
pr.In.Header.Set("X-Forwarded-Proto", "https")
pr.In.Header.Set("X-Forwarded-Port", "9999")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"untrusted: XFF must be replaced")
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"untrusted: X-Real-IP must be replaced")
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"),
"untrusted: host must be from direct connection")
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"),
"untrusted: proto must be locally resolved")
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"),
"untrusted: port must be locally computed")
})
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
"nil trusted list: should strip and use RemoteAddr")
})
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
"no upstream XFF: should set direct connection IP")
})
}
// newProxyRequest creates an httputil.ProxyRequest suitable for testing // newProxyRequest creates an httputil.ProxyRequest suitable for testing
// the Rewrite function. It simulates what httputil.ReverseProxy does internally: // the Rewrite function. It simulates what httputil.ReverseProxy does internally:
// Out is a shallow clone of In with headers copied. // Out is a shallow clone of In with headers copied.

View File

@@ -0,0 +1,60 @@
package proxy
import (
"net/netip"
"strings"
)
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
if len(trusted) == 0 {
return false
}
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
// It walks the XFF chain right-to-left, skipping IPs that match trusted prefixes.
// The first untrusted IP is the real client.
//
// If the trusted list is empty or remoteAddr is not trusted, it returns the
// remoteAddr IP directly (ignoring any forwarding headers).
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
remoteIP := extractClientIP(remoteAddr)
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
return remoteIP
}
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !IsTrustedProxy(ip, trusted) {
return ip
}
}
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}

View File

@@ -0,0 +1,129 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsTrustedProxy(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("fd00::/8"),
}
tests := []struct {
name string
ip string
trusted []netip.Prefix
want bool
}{
{"empty trusted list", "10.0.0.1", nil, false},
{"IP within /8 prefix", "10.1.2.3", trusted, true},
{"IP within /24 prefix", "192.168.1.100", trusted, true},
{"IP outside all prefixes", "203.0.113.50", trusted, false},
{"boundary IP just outside prefix", "192.168.2.1", trusted, false},
{"unparseable IP", "not-an-ip", trusted, false},
{"IPv6 in trusted range", "fd00::1", trusted, true},
{"IPv6 outside range", "2001:db8::1", trusted, false},
{"empty string", "", trusted, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted))
})
}
}
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xff string
trusted []netip.Prefix
want string
}{
{
name: "empty trusted list returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4",
trusted: nil,
want: "203.0.113.50",
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4, 10.0.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
xff: "",
trusted: trusted,
want: "10.0.0.1",
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trusted: trusted,
want: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xff: " 203.0.113.50 , 10.0.0.2 ",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "XFF with empty segments",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50,,10.0.0.2",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "10.0.0.1",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted))
})
}
}

View File

@@ -16,6 +16,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"path/filepath" "path/filepath"
"time" "time"
@@ -82,6 +83,10 @@ type Server struct {
// ForwardedProto overrides the X-Forwarded-Proto value sent to backends. // ForwardedProto overrides the X-Forwarded-Proto value sent to backends.
// Valid values: "auto" (detect from TLS), "http", "https". // Valid values: "auto" (detect from TLS), "http", "https".
ForwardedProto string ForwardedProto string
// TrustedProxies is a list of IP prefixes for trusted upstream proxies.
// When set, forwarding headers from these sources are preserved and
// appended to instead of being stripped.
TrustedProxies []netip.Prefix
} }
// NotifyStatus sends a status update to management about tunnel connectivity // NotifyStatus sends a status update to management about tunnel connectivity
@@ -217,13 +222,13 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
} }
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.netbird, s.ForwardedProto, s.Logger) s.proxy = proxy.NewReverseProxy(s.netbird, s.ForwardedProto, s.TrustedProxies, s.Logger)
// Configure the authentication middleware. // Configure the authentication middleware.
s.auth = auth.NewMiddleware(s.Logger) s.auth = auth.NewMiddleware(s.Logger)
// Configure Access logs to management server. // Configure Access logs to management server.
accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger) accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
if s.DebugEndpointEnabled { if s.DebugEndpointEnabled {
debugAddr := debugEndpointAddr(s.DebugEndpointAddress) debugAddr := debugEndpointAddr(s.DebugEndpointAddress)

43
proxy/trustedproxy.go Normal file
View File

@@ -0,0 +1,43 @@
package proxy
import (
"fmt"
"net/netip"
"strings"
)
// ParseTrustedProxies parses a comma-separated list of CIDR prefixes or bare IPs
// into a slice of netip.Prefix values suitable for trusted proxy configuration.
// Bare IPs are converted to single-host prefixes (/32 or /128).
func ParseTrustedProxies(raw string) ([]netip.Prefix, error) {
if raw == "" {
return nil, nil
}
parts := strings.Split(raw, ",")
prefixes := make([]netip.Prefix, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
prefix, err := netip.ParsePrefix(part)
if err == nil {
prefixes = append(prefixes, prefix)
continue
}
addr, addrErr := netip.ParseAddr(part)
if addrErr != nil {
return nil, fmt.Errorf("parse trusted proxy %q: not a valid CIDR or IP: %w", part, addrErr)
}
bits := 32
if addr.Is6() {
bits = 128
}
prefixes = append(prefixes, netip.PrefixFrom(addr, bits))
}
return prefixes, nil
}

View File

@@ -0,0 +1,90 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseTrustedProxies(t *testing.T) {
tests := []struct {
name string
raw string
want []netip.Prefix
wantErr bool
}{
{
name: "empty string returns nil",
raw: "",
want: nil,
},
{
name: "single CIDR",
raw: "10.0.0.0/8",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "single bare IPv4",
raw: "1.2.3.4",
want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")},
},
{
name: "single bare IPv6",
raw: "::1",
want: []netip.Prefix{netip.MustParsePrefix("::1/128")},
},
{
name: "comma-separated CIDRs",
raw: "10.0.0.0/8, 192.168.1.0/24",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "mixed CIDRs and bare IPs",
raw: "10.0.0.0/8, 1.2.3.4, fd00::/8",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("1.2.3.4/32"),
netip.MustParsePrefix("fd00::/8"),
},
},
{
name: "whitespace around entries",
raw: " 10.0.0.0/8 , 192.168.0.0/16 ",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "trailing comma produces no extra entry",
raw: "10.0.0.0/8,",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "invalid entry",
raw: "not-an-ip",
wantErr: true,
},
{
name: "partially invalid",
raw: "10.0.0.0/8, garbage",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseTrustedProxies(tt.raw)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}