mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 10:16:38 +00:00
Add reverse proxy header security and forwarding
- Rewrite Host header to backend target (configurable via pass_host_header per mapping) - Strip and set X-Forwarded-For/X-Real-IP from direct connection (trust boundary) - Set X-Forwarded-Host and X-Forwarded-Proto headers - Strip nb_session cookie and session_token query param before forwarding - Add --forwarded-proto flag (auto/http/https) for proto detection - Fix OIDC redirect hardcoded https scheme - Add pass_host_header to proto, API, and management model
This commit is contained in:
@@ -3,22 +3,28 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
transport http.RoundTripper
|
||||
mappingsMux sync.RWMutex
|
||||
mappings map[string]Mapping
|
||||
logger *log.Logger
|
||||
transport http.RoundTripper
|
||||
// forwardedProto overrides the X-Forwarded-Proto header value.
|
||||
// Valid values: "auto" (detect from TLS), "http", "https".
|
||||
forwardedProto string
|
||||
mappingsMux sync.RWMutex
|
||||
mappings map[string]Mapping
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewReverseProxy configures a new NetBird ReverseProxy.
|
||||
@@ -27,19 +33,20 @@ type ReverseProxy struct {
|
||||
// between requested URLs and targets.
|
||||
// The internal mappings can be modified using the AddMapping
|
||||
// and RemoveMapping functions.
|
||||
func NewReverseProxy(transport http.RoundTripper, logger *log.Logger) *ReverseProxy {
|
||||
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, logger *log.Logger) *ReverseProxy {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &ReverseProxy{
|
||||
transport: transport,
|
||||
mappings: make(map[string]Mapping),
|
||||
logger: logger,
|
||||
transport: transport,
|
||||
forwardedProto: forwardedProto,
|
||||
mappings: make(map[string]Mapping),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
target, serviceId, accountID, exists := p.findTargetForRequest(r)
|
||||
result, exists := p.findTargetForRequest(r)
|
||||
if !exists {
|
||||
requestID := getRequestID(r)
|
||||
web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
|
||||
@@ -49,25 +56,101 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), serviceId)
|
||||
ctx := withServiceId(r.Context(), result.serviceID)
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, accountID)
|
||||
ctx = withAccountId(ctx, result.accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, accountID)
|
||||
ctx = roundtrip.WithAccountID(ctx, result.accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
// pointer in the context, and mutate the struct here so outer middleware can read it.
|
||||
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
|
||||
capturedData.SetServiceId(serviceId)
|
||||
capturedData.SetAccountId(accountID)
|
||||
capturedData.SetServiceId(result.serviceID)
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
}
|
||||
|
||||
// Set up a reverse proxy using the transport and then use it to serve the request.
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
proxy.Transport = p.transport
|
||||
proxy.ErrorHandler = proxyErrorHandler
|
||||
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.passHostHeader),
|
||||
Transport: p.transport,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
|
||||
// inbound requests to target the backend service while setting security-relevant
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(target)
|
||||
if passHostHeader {
|
||||
r.Out.Host = r.In.Host
|
||||
} else {
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
|
||||
|
||||
// Strip any incoming forwarding headers since this proxy is the trust
|
||||
// boundary and set them fresh based on the direct connection.
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Real-IP", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
r.Out.Header.Set("X-Forwarded-Proto", proto)
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
|
||||
|
||||
stripSessionCookie(r)
|
||||
stripSessionTokenQuery(r)
|
||||
}
|
||||
}
|
||||
|
||||
// stripSessionCookie removes the proxy's session cookie from the outgoing
|
||||
// request while preserving all other cookies.
|
||||
func stripSessionCookie(r *httputil.ProxyRequest) {
|
||||
cookies := r.In.Cookies()
|
||||
r.Out.Header.Del("Cookie")
|
||||
for _, c := range cookies {
|
||||
if c.Name != auth.SessionCookieName {
|
||||
r.Out.AddCookie(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripSessionTokenQuery removes the OIDC session_token query parameter from
|
||||
// the outgoing URL to prevent credential leakage to backends.
|
||||
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
|
||||
q := r.Out.URL.Query()
|
||||
if q.Has("session_token") {
|
||||
q.Del("session_token")
|
||||
r.Out.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
|
||||
// which is always in host:port format.
|
||||
func extractClientIP(remoteAddr string) string {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// extractForwardedPort returns the port from the Host header if present,
|
||||
// otherwise defaults to the standard port for the resolved protocol.
|
||||
func extractForwardedPort(host, resolvedProto string) string {
|
||||
_, port, err := net.SplitHostPort(host)
|
||||
if err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
if resolvedProto == "https" {
|
||||
return "443"
|
||||
}
|
||||
return "80"
|
||||
}
|
||||
|
||||
// proxyErrorHandler handles errors from the reverse proxy and serves
|
||||
|
||||
313
proxy/internal/proxy/reverseproxy_test.go
Normal file
313
proxy/internal/proxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, true)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "public.example.com", pr.Out.Host,
|
||||
"Host header should be the original client host")
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
|
||||
"URL host (used for TLS/SNI) must still point to the backend")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"should be set to the connecting client IP")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Forwarded-For from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"spoofed XFF must be replaced, not appended to")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Real-IP from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Real-IP", "10.0.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"spoofed X-Real-IP must be replaced")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "myapp.example.com:8443", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
for _, c := range cookies {
|
||||
assert.NotEqual(t, auth.SessionCookieName, c.Name,
|
||||
"proxy session cookie must not be forwarded to backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves other cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "app_session", Value: "app-value"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "tracking", Value: "track-value"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
cookieNames := make([]string, 0, len(cookies))
|
||||
for _, c := range cookies {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
assert.Contains(t, cookieNames, "app_session", "non-proxy cookies should be preserved")
|
||||
assert.Contains(t, cookieNames, "tracking", "non-proxy cookies should be preserved")
|
||||
assert.NotContains(t, cookieNames, auth.SessionCookieName, "proxy cookie must be stripped")
|
||||
})
|
||||
|
||||
t.Run("handles request with no cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get("Cookie"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.URL.Query().Get("session_token"),
|
||||
"OIDC session token must be stripped from backend request")
|
||||
assert.Equal(t, "keep", pr.Out.URL.Query().Get("other"),
|
||||
"other query parameters must be preserved")
|
||||
})
|
||||
|
||||
t.Run("preserves query when no session_token present", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/api?foo=bar&baz=qux", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "bar", pr.Out.URL.Query().Get("foo"))
|
||||
assert.Equal(t, "qux", pr.Out.URL.Query().Get("baz"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, false)
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.URL.Scheme)
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host)
|
||||
assert.Equal(t, "/app/somepath", pr.Out.URL.Path,
|
||||
"SetURL should join the target base path with the request path")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||
{"IPv6 with port", "[::1]:12345", "::1"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
|
||||
{"IPv6 without brackets fallback", "::1", "::1"},
|
||||
{"empty string fallback", "", ""},
|
||||
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractForwardedPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
resolvedProto string
|
||||
expected string
|
||||
}{
|
||||
{"explicit port in host", "example.com:8443", "https", "8443"},
|
||||
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
|
||||
{"no port defaults to 443 for https", "example.com", "https", "443"},
|
||||
{"no port defaults to 80 for http", "example.com", "http", "80"},
|
||||
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
|
||||
{"IPv6 host without port", "::1", "https", "443"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
|
||||
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
|
||||
// Out is a shallow clone of In with headers copied.
|
||||
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
in := httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
in.RemoteAddr = remoteAddr
|
||||
in.Host = parsed.Host
|
||||
|
||||
out := in.Clone(in.Context())
|
||||
out.Header = in.Header.Clone()
|
||||
|
||||
return &httputil.ProxyRequest{In: in, Out: out}
|
||||
}
|
||||
@@ -11,22 +11,22 @@ import (
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
PassHostHeader bool
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
|
||||
type targetResult struct {
|
||||
url *url.URL
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
if p.mappings == nil {
|
||||
p.mappingsMux.RUnlock()
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
p.mappings = make(map[string]Mapping)
|
||||
// There cannot be any loaded Mappings as we have only just initialized.
|
||||
return nil, "", "", false
|
||||
}
|
||||
defer p.mappingsMux.RUnlock()
|
||||
|
||||
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||
@@ -38,7 +38,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
|
||||
p.logger.Debugf("looking for mapping for host: %s, path: %s", host, req.URL.Path)
|
||||
m, exists := p.mappings[host]
|
||||
if !exists {
|
||||
return nil, "", "", false
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
|
||||
@@ -52,18 +52,20 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
return m.Paths[path], m.ID, m.AccountID, true
|
||||
return targetResult{
|
||||
url: m.Paths[path],
|
||||
serviceID: m.ID,
|
||||
accountID: m.AccountID,
|
||||
passHostHeader: m.PassHostHeader,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
return nil, "", "", false
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
if p.mappings == nil {
|
||||
p.mappings = make(map[string]Mapping)
|
||||
}
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user