mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -47,10 +48,10 @@ func (o ResponseOrigin) String() string {
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
RequestID string
|
||||
ServiceId string
|
||||
ServiceId types.ServiceID
|
||||
AccountId types.AccountID
|
||||
Origin ResponseOrigin
|
||||
ClientIP string
|
||||
ClientIP netip.Addr
|
||||
UserID string
|
||||
AuthMethod string
|
||||
}
|
||||
@@ -63,14 +64,14 @@ func (c *CapturedData) GetRequestID() string {
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
func (c *CapturedData) SetServiceId(serviceId string) {
|
||||
func (c *CapturedData) SetServiceId(serviceId types.ServiceID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ServiceId = serviceId
|
||||
}
|
||||
|
||||
// GetServiceId safely gets the service ID
|
||||
func (c *CapturedData) GetServiceId() string {
|
||||
func (c *CapturedData) GetServiceId() types.ServiceID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ServiceId
|
||||
@@ -105,14 +106,14 @@ func (c *CapturedData) GetOrigin() ResponseOrigin {
|
||||
}
|
||||
|
||||
// SetClientIP safely sets the resolved client IP.
|
||||
func (c *CapturedData) SetClientIP(ip string) {
|
||||
func (c *CapturedData) SetClientIP(ip netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ClientIP = ip
|
||||
}
|
||||
|
||||
// GetClientIP safely gets the resolved client IP.
|
||||
func (c *CapturedData) GetClientIP() string {
|
||||
func (c *CapturedData) GetClientIP() netip.Addr {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ClientIP
|
||||
@@ -161,13 +162,13 @@ func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
return data
|
||||
}
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId string) context.Context {
|
||||
func withServiceId(ctx context.Context, serviceId types.ServiceID) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) string {
|
||||
func ServiceIdFromContext(ctx context.Context) types.ServiceID {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(string)
|
||||
serviceId, ok := v.(types.ServiceID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
func BenchmarkServeHTTP(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
ID: types.ServiceID(rand.Text()),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
@@ -66,7 +66,7 @@ func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
target = id
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: id,
|
||||
ID: types.ServiceID(id),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*proxy.PathTarget{
|
||||
@@ -118,7 +118,7 @@ func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
ID: types.ServiceID(rand.Text()),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: paths,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -86,9 +87,7 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||
}
|
||||
if pt.RequestTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, pt.RequestTimeout)
|
||||
defer cancel()
|
||||
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
@@ -142,9 +141,9 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
clientIP := extractHostIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
if isTrustedAddr(clientIP, p.trustedProxies) {
|
||||
p.setTrustedForwardingHeaders(r, clientIP)
|
||||
} else {
|
||||
p.setUntrustedForwardingHeaders(r, clientIP)
|
||||
@@ -214,12 +213,14 @@ func normalizeHost(u *url.URL) string {
|
||||
// setTrustedForwardingHeaders appends to the existing forwarding header chain
|
||||
// and preserves upstream-provided headers when the direct connection is from
|
||||
// a trusted proxy.
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
|
||||
ipStr := clientIP.String()
|
||||
|
||||
// Append the direct connection IP to the existing X-Forwarded-For chain.
|
||||
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+ipStr)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", ipStr)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
|
||||
@@ -227,7 +228,7 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
|
||||
r.Out.Header.Set("X-Real-IP", realIP)
|
||||
} else {
|
||||
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
|
||||
r.Out.Header.Set("X-Real-IP", resolved)
|
||||
r.Out.Header.Set("X-Real-IP", resolved.String())
|
||||
}
|
||||
|
||||
// Preserve upstream X-Forwarded-Host if present.
|
||||
@@ -257,10 +258,11 @@ func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, cli
|
||||
// sets them fresh based on the direct connection. This is the default
|
||||
// behavior when no trusted proxies are configured or the direct connection
|
||||
// is from an untrusted source.
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP netip.Addr) {
|
||||
ipStr := clientIP.String()
|
||||
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Real-IP", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-For", ipStr)
|
||||
r.Out.Header.Set("X-Real-IP", ipStr)
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
r.Out.Header.Set("X-Forwarded-Proto", proto)
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
|
||||
@@ -288,16 +290,6 @@ func stripSessionTokenQuery(r *httputil.ProxyRequest) {
|
||||
}
|
||||
}
|
||||
|
||||
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
|
||||
// which is always in host:port format.
|
||||
func extractClientIP(remoteAddr string) string {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// extractForwardedPort returns the port from the Host header if present,
|
||||
// otherwise defaults to the standard port for the resolved protocol.
|
||||
func extractForwardedPort(host, resolvedProto string) string {
|
||||
@@ -327,10 +319,12 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
web.ServeErrorPage(w, r, code, title, message, requestID, status)
|
||||
}
|
||||
|
||||
// getClientIP retrieves the resolved client IP from context.
|
||||
// getClientIP retrieves the resolved client IP string from context.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetClientIP()
|
||||
if ip := capturedData.GetClientIP(); ip.IsValid() {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -284,23 +284,23 @@ func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
func TestExtractHostIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
expected netip.Addr
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||
{"IPv6 with port", "[::1]:12345", "::1"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
|
||||
{"IPv6 without brackets fallback", "::1", "::1"},
|
||||
{"empty string fallback", "", ""},
|
||||
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
|
||||
{"IPv4 with port", "192.168.1.1:12345", netip.MustParseAddr("192.168.1.1")},
|
||||
{"IPv6 with port", "[::1]:12345", netip.MustParseAddr("::1")},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", netip.MustParseAddr("2001:db8::1")},
|
||||
{"IPv4 without port fallback", "192.168.1.1", netip.MustParseAddr("192.168.1.1")},
|
||||
{"IPv6 without brackets fallback", "::1", netip.MustParseAddr("::1")},
|
||||
{"empty string fallback", "", netip.Addr{}},
|
||||
{"public IP", "203.0.113.50:9999", netip.MustParseAddr("203.0.113.50")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
|
||||
assert.Equal(t, tt.expected, extractHostIP(tt.remoteAddr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,8 +30,9 @@ type PathTarget struct {
|
||||
CustomHeaders map[string]string
|
||||
}
|
||||
|
||||
// Mapping describes how a domain is routed by the HTTP reverse proxy.
|
||||
type Mapping struct {
|
||||
ID string
|
||||
ID types.ServiceID
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*PathTarget
|
||||
@@ -42,7 +43,7 @@ type Mapping struct {
|
||||
type targetResult struct {
|
||||
target *PathTarget
|
||||
matchedPath string
|
||||
serviceID string
|
||||
serviceID types.ServiceID
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
rewriteRedirects bool
|
||||
@@ -101,8 +102,13 @@ func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) {
|
||||
// RemoveMapping removes the mapping for the given host and reports whether it existed.
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) bool {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
if _, ok := p.mappings[m.Host]; !ok {
|
||||
return false
|
||||
}
|
||||
delete(p.mappings, m.Host)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -7,21 +7,11 @@ import (
|
||||
|
||||
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
|
||||
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
if len(trusted) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
if err != nil || len(trusted) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isTrustedAddr(addr.Unmap(), trusted)
|
||||
}
|
||||
|
||||
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
|
||||
@@ -30,10 +20,10 @@ func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
//
|
||||
// If the trusted list is empty or remoteAddr is not trusted, it returns the
|
||||
// remoteAddr IP directly (ignoring any forwarding headers).
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
remoteIP := extractClientIP(remoteAddr)
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) netip.Addr {
|
||||
remoteIP := extractHostIP(remoteAddr)
|
||||
|
||||
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
|
||||
if len(trusted) == 0 || !isTrustedAddr(remoteIP, trusted) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
@@ -47,14 +37,45 @@ func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !IsTrustedProxy(ip, trusted) {
|
||||
return ip
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
if !isTrustedAddr(addr, trusted) {
|
||||
return addr
|
||||
}
|
||||
}
|
||||
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
if addr, err := netip.ParseAddr(first); err == nil {
|
||||
return addr.Unmap()
|
||||
}
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// extractHostIP parses the IP from a host:port string and returns it unmapped.
|
||||
func extractHostIP(hostPort string) netip.Addr {
|
||||
if ap, err := netip.ParseAddrPort(hostPort); err == nil {
|
||||
return ap.Addr().Unmap()
|
||||
}
|
||||
if addr, err := netip.ParseAddr(hostPort); err == nil {
|
||||
return addr.Unmap()
|
||||
}
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// isTrustedAddr checks if the given address falls within any of the trusted prefixes.
|
||||
func isTrustedAddr(addr netip.Addr, trusted []netip.Prefix) bool {
|
||||
if !addr.IsValid() {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -48,77 +48,77 @@ func TestResolveClientIP(t *testing.T) {
|
||||
remoteAddr string
|
||||
xff string
|
||||
trusted []netip.Prefix
|
||||
want string
|
||||
want netip.Addr
|
||||
}{
|
||||
{
|
||||
name: "empty trusted list returns RemoteAddr",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4",
|
||||
trusted: nil,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "untrusted RemoteAddr ignores XFF",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4, 10.0.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with single client in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr walks past trusted entries in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.1",
|
||||
want: netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
{
|
||||
name: "all XFF IPs trusted returns leftmost",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.2",
|
||||
want: netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
{
|
||||
name: "XFF with whitespace",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: " 203.0.113.50 , 10.0.0.2 ",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "XFF with empty segments",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50,,10.0.0.2",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "multi-hop with mixed trust",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
remoteAddr: "10.0.0.1",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
want: netip.MustParseAddr("203.0.113.50"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user