[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"net/netip"
"sync"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string {
type CapturedData struct {
mu sync.RWMutex
RequestID string
ServiceId string
ServiceId types.ServiceID
AccountId types.AccountID
Origin ResponseOrigin
ClientIP string
ClientIP netip.Addr
UserID string
AuthMethod string
}
@@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string {
}
// SetServiceId safely sets the service ID
func (c *CapturedData) SetServiceId(serviceId string) {
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
c.mu.Lock()
defer c.mu.Unlock()
c.ServiceId = serviceId
}
// GetServiceId safely gets the service ID
func (c *CapturedData) GetServiceId() string {
func (c *CapturedData) GetServiceId() types.ServiceID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ServiceId
@@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin {
}
// SetClientIP safely sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip string) {
func (c *CapturedData) SetClientIP(ip netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
}
// GetClientIP safely gets the resolved client IP.
func (c *CapturedData) GetClientIP() string {
func (c *CapturedData) GetClientIP() netip.Addr {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
@@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
return data
}
func withServiceId(ctx context.Context, serviceId string) context.Context {
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) string {
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(string)
serviceId, ok := v.(types.ServiceID)
if !ok {
return ""
}

View File

@@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
func BenchmarkServeHTTP(b *testing.B) {
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
ID: types.ServiceID(rand.Text()),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: map[string]*proxy.PathTarget{
@@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
target = id
}
rp.AddMapping(proxy.Mapping{
ID: id,
ID: types.ServiceID(id),
AccountID: types.AccountID(rand.Text()),
Host: host,
Paths: map[string]*proxy.PathTarget{
@@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
}
}
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
ID: types.ServiceID(rand.Text()),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: paths,

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
)
@@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx = roundtrip.WithSkipTLSVerify(ctx)
}
if pt.RequestTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
defer cancel()
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
}
rewriteMatchedPath := result.matchedPath
@@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Header.Set(k, v)
}
clientIP := extractClientIP(r.In.RemoteAddr)
clientIP := extractHostIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) {
if isTrustedAddr(clientIP, p.trustedProxies) {
p.setTrustedForwardingHeaders(r, clientIP)
} else {
p.setUntrustedForwardingHeaders(r, clientIP)
@@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string {
// setTrustedForwardingHeaders appends to the existing forwarding header chain
// and preserves upstream-provided headers when the direct connection is from
// a trusted proxy.
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
ipStr := clientIP.String()
// Append the direct connection IP to the existing X-Forwarded-For chain.
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr)
} else {
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Forwarded-For", ipStr)
}
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
@@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
r.Out.Header.Set("X-Real-IP", realIP)
} else {
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
r.Out.Header.Set("X-Real-IP", resolved)
r.Out.Header.Set("X-Real-IP", resolved.String())
}
// Preserve upstream X-Forwarded-Host if present.
@@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
// sets them fresh based on the direct connection. This is the default
// behavior when no trusted proxies are configured or the direct connection
// is from an untrusted source.
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
ipStr := clientIP.String()
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Real-IP", clientIP)
r.Out.Header.Set("X-Forwarded-For", ipStr)
r.Out.Header.Set("X-Real-IP", ipStr)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
@@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) {
}
}
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
// which is always in host:port format.
func extractClientIP(remoteAddr string) string {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return ip
}
// extractForwardedPort returns the port from the Host header if present,
// otherwise defaults to the standard port for the resolved protocol.
func extractForwardedPort(host, resolvedProto string) string {
@@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
web.ServeErrorPage(w, r, code, title, message, requestID, status)
}
// getClientIP retrieves the resolved client IP from context.
// getClientIP retrieves the resolved client IP string from context.
func getClientIP(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
return capturedData.GetClientIP()
if ip := capturedData.GetClientIP(); ip.IsValid() {
return ip.String()
}
}
return ""
}

View File

@@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
})
}
func TestExtractClientIP(t *testing.T) {
func TestExtractHostIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
expected netip.Addr
}{
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"IPv6 with port", "[::1]:12345", "::1"},
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
{"IPv6 without brackets fallback", "::1", "::1"},
{"empty string fallback", "", ""},
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
{"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")},
{"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")},
{"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")},
{"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")},
{"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")},
{"empty string fallback", "", netip.Addr{}},
{"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr))
})
}
}

View File

@@ -30,8 +30,9 @@ type PathTarget struct {
CustomHeaders map[string]string
}
// Mapping describes how a domain is routed by the HTTP reverse proxy.
type Mapping struct {
ID string
ID types.ServiceID
AccountID types.AccountID
Host string
Paths map[string]*PathTarget
@@ -42,7 +43,7 @@ type Mapping struct {
type targetResult struct {
target *PathTarget
matchedPath string
serviceID string
serviceID types.ServiceID
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
@@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) {
p.mappings[m.Host] = m
}
func (p *ReverseProxy) RemoveMapping(m Mapping) {
// RemoveMapping removes the mapping for the given host and reports whether it existed.
func (p *ReverseProxy) RemoveMapping(m Mapping) bool {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
if _, ok := p.mappings[m.Host]; !ok {
return false
}
delete(p.mappings, m.Host)
return true
}

View File

@@ -7,21 +7,11 @@ import (
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
if len(trusted) == 0 {
return false
}
addr, err := netip.ParseAddr(ipStr)
if err != nil {
if err != nil || len(trusted) == 0 {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
return isTrustedAddr(addr.Unmap(), trusted)
}
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
@@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
//
// If the trusted list is empty or remoteAddr is not trusted, it returns the
// remoteAddr IP directly (ignoring any forwarding headers).
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
remoteIP := extractClientIP(remoteAddr)
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr {
remoteIP := extractHostIP(remoteAddr)
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) {
return remoteIP
}
@@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
if ip == "" {
continue
}
if !IsTrustedProxy(ip, trusted) {
return ip
addr, err := netip.ParseAddr(ip)
if err != nil {
continue
}
addr = addr.Unmap()
if !isTrustedAddr(addr, trusted) {
return addr
}
}
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
if addr, err := netip.ParseAddr(first); err == nil {
return addr.Unmap()
}
}
return remoteIP
}
// extractHostIP parses the IP from a host:port string and returns it unmapped.
func extractHostIP(hostPort string) netip.Addr {
if ap, err := netip.ParseAddrPort(hostPort); err == nil {
return ap.Addr().Unmap()
}
if addr, err := netip.ParseAddr(hostPort); err == nil {
return addr.Unmap()
}
return netip.Addr{}
}
// isTrustedAddr checks if the given address falls within any of the trusted prefixes.
func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool {
if !addr.IsValid() {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) {
remoteAddr string
xff string
trusted []netip.Prefix
want string
want netip.Addr
}{
{
name: "empty trusted list returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4",
trusted: nil,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4, 10.0.0.1",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
xff: "",
trusted: trusted,
want: "10.0.0.1",
want: netip.MustParseAddr("10.0.0.1"),
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trusted: trusted,
want: "10.0.0.2",
want: netip.MustParseAddr("10.0.0.2"),
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xff: " 203.0.113.50 , 10.0.0.2 ",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "XFF with empty segments",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50,,10.0.0.2",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
{
name: "RemoteAddr without port",
remoteAddr: "10.0.0.1",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
want: netip.MustParseAddr("203.0.113.50"),
},
}
for _, tt := range tests {