Add option to rewrite redirects

This commit is contained in:
Viktor Liu
2026-02-09 00:38:16 +08:00
parent 260c46df04
commit 3630ebb3ae
9 changed files with 459 additions and 36 deletions

View File

@@ -95,6 +95,7 @@ type ReverseProxy struct {
Targets []Target `gorm:"serializer:json"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
@@ -174,14 +175,15 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
}
resp := &api.ReverseProxy{
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
Targets: apiTargets,
Enabled: r.Enabled,
PassHostHeader: &r.PassHostHeader,
Auth: authConfig,
Meta: meta,
Id: r.ID,
Name: r.Name,
Domain: r.Domain,
Targets: apiTargets,
Enabled: r.Enabled,
PassHostHeader: &r.PassHostHeader,
RewriteRedirects: &r.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if r.ProxyCluster != "" {
@@ -203,6 +205,9 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
path = *target.Path
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
@@ -236,14 +241,15 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: r.ID,
Domain: r.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: r.AccountID,
PassHostHeader: r.PassHostHeader,
Type: operationToProtoType(operation),
Id: r.ID,
Domain: r.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: r.AccountID,
PassHostHeader: r.PassHostHeader,
RewriteRedirects: r.RewriteRedirects,
}
}
@@ -288,6 +294,10 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
r.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
r.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
r.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
@@ -358,6 +368,7 @@ func (r *ReverseProxy) Copy() *ReverseProxy {
Targets: targets,
Enabled: r.Enabled,
PassHostHeader: r.PassHostHeader,
RewriteRedirects: r.RewriteRedirects,
Auth: r.Auth,
Meta: r.Meta,
SessionPrivateKey: r.SessionPrivateKey,

View File

@@ -3,6 +3,7 @@ package proxy
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
@@ -84,6 +85,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Transport: p.transport,
ErrorHandler: proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r)
}
rp.ServeHTTP(w, r.WithContext(ctx))
}
@@ -124,6 +128,62 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
}
}
// rewriteLocationFunc returns a ModifyResponse function that rewrites Location
// headers in backend responses when they point to the backend's address,
// replacing them with the public-facing host and scheme.
func (p *ReverseProxy) rewriteLocationFunc(target *url.URL, matchedPath string, inReq *http.Request) func(*http.Response) error {
publicHost := inReq.Host
publicScheme := auth.ResolveProto(p.forwardedProto, inReq.TLS)
return func(resp *http.Response) error {
location := resp.Header.Get("Location")
if location == "" {
return nil
}
locURL, err := url.Parse(location)
if err != nil {
return fmt.Errorf("parse Location header %q: %w", location, err)
}
// Only rewrite absolute URLs that point to the backend.
if locURL.Host == "" || !hostsEqual(locURL, target) {
return nil
}
locURL.Host = publicHost
locURL.Scheme = publicScheme
// Re-add the stripped path prefix so the client reaches the correct route.
// TrimRight prevents double slashes when matchedPath has a trailing slash.
if matchedPath != "" && matchedPath != "/" {
locURL.Path = strings.TrimRight(matchedPath, "/") + "/" + strings.TrimLeft(locURL.Path, "/")
}
resp.Header.Set("Location", locURL.String())
return nil
}
}
// hostsEqual compares two URL authorities, normalizing default ports per
// RFC 3986 Section 6.2.3 (https://443 == https, http://80 == http).
func hostsEqual(a, b *url.URL) bool {
return normalizeHost(a) == normalizeHost(b)
}
// normalizeHost strips the port from a URL's Host field if it matches the
// scheme's default port (443 for https, 80 for http).
func normalizeHost(u *url.URL) string {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
return u.Host
}
if (u.Scheme == "https" && port == "443") || (u.Scheme == "http" && port == "80") {
return host
}
return u.Host
}
// setTrustedForwardingHeaders appends to the existing forwarding header chain
// and preserves upstream-provided headers when the direct connection is from
// a trusted proxy.

View File

@@ -470,6 +470,329 @@ func TestRewriteFunc_TrustedProxy(t *testing.T) {
})
}
// TestRewriteFunc_PathForwarding verifies what path the backend actually
// receives given different configurations. This simulates the full pipeline:
// management builds a target URL (with matching prefix baked into the path),
// then the proxy strips the prefix and SetURL re-joins with the target path.
func TestRewriteFunc_PathForwarding(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
// Simulate what ToProtoMapping does: target URL includes the matching
// prefix as its path component, so the proxy strips-then-re-adds.
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise/", pr.Out.URL.Path,
"backend sees /heise/ because prefix is stripped then re-added by SetURL")
})
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise/article/123", pr.Out.URL.Path,
"subpath is preserved on top of the re-added prefix")
})
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/", pr.Out.URL.Path,
"without path in target URL, backend sees / (true prefix stripping)")
})
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/article/123", pr.Out.URL.Path,
"without path in target URL, prefix is truly stripped")
})
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise", pr.Out.URL.Path,
"root path match must not strip anything")
})
}
func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
newReq := func(rawURL string) *http.Request {
t.Helper()
r := httptest.NewRequest(http.MethodGet, rawURL, nil)
parsed, _ := url.Parse(rawURL)
r.Host = parsed.Host
return r
}
run := func(p *ReverseProxy, matchedPath string, inReq *http.Request, location string) (*http.Response, error) {
t.Helper()
modifyResp := p.rewriteLocationFunc(target, matchedPath, inReq)
resp := &http.Response{Header: http.Header{}}
if location != "" {
resp.Header.Set("Location", location)
}
err := modifyResp(resp)
return resp, err
}
t.Run("rewrites Location pointing to backend", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"),
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
t.Run("does not rewrite Location pointing to other host", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"https://other.example.com/path")
require.NoError(t, err)
assert.Equal(t, "https://other.example.com/path", resp.Header.Get("Location"))
})
t.Run("does not rewrite relative Location", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"/dashboard")
require.NoError(t, err)
assert.Equal(t, "/dashboard", resp.Header.Get("Location"))
})
t.Run("re-adds stripped path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/users"),
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("uses resolved proto for scheme", func(t *testing.T) {
resp, err := run(newProxy("auto"), "", newReq("http://public.example.com/"),
"http://backend.internal:8080/path")
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"))
})
t.Run("no-op when Location header is empty", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), "")
require.NoError(t, err)
assert.Empty(t, resp.Header.Get("Location"))
})
t.Run("does not prepend root path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/", newReq("https://public.example.com/login"),
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
// --- Edge cases: query parameters and fragments ---
t.Run("preserves query parameters", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:8080/login?redirect=%2Fdashboard&lang=en")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login?redirect=%2Fdashboard&lang=en", resp.Header.Get("Location"))
})
t.Run("preserves fragment", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:8080/docs#section-2")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/docs#section-2", resp.Header.Get("Location"))
})
t.Run("preserves query parameters and fragment together", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:8080/search?q=test&page=1#results")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/search?q=test&page=1#results", resp.Header.Get("Location"))
})
t.Run("preserves query parameters with path prefix re-added", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/search"),
"http://backend.internal:8080/search?q=hello")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/search?q=hello", resp.Header.Get("Location"))
})
// --- Edge cases: slash handling ---
t.Run("no double slash when matchedPath has trailing slash", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api/", newReq("https://public.example.com/api/users"),
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app", newReq("https://public.example.com/app/"),
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with trailing-slash path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app/", newReq("https://public.example.com/app/"),
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("preserves trailing slash on redirect path", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:8080/path/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path/", resp.Header.Get("Location"))
})
t.Run("backend redirect to bare root", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"),
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/", resp.Header.Get("Location"))
})
// --- Edge cases: host/port matching ---
t.Run("does not rewrite when backend host matches but port differs", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:9090/other")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal:9090/other", resp.Header.Get("Location"),
"Different port means different host authority, must not rewrite")
})
t.Run("rewrites when redirect omits default port matching target", func(t *testing.T) {
// Target is backend.internal:8080, redirect is to backend.internal (no port).
// These are different authorities, so should NOT rewrite.
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal/path")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal/path", resp.Header.Get("Location"),
"backend.internal != backend.internal:8080, must not rewrite")
})
t.Run("rewrites when target has :443 but redirect omits it for https", func(t *testing.T) {
// Target: heise.de:443, redirect: https://heise.de/path (no :443 because it's default)
// Per RFC 3986, these are the same authority.
target443, _ := url.Parse("https://heise.de:443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target443, "", newReq("https://public.example.com/"))
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de:443 and heise.de are the same for https")
})
t.Run("rewrites when target has :80 but redirect omits it for http", func(t *testing.T) {
target80, _ := url.Parse("http://backend.local:80")
p := newProxy("http")
modifyResp := p.rewriteLocationFunc(target80, "", newReq("http://public.example.com/"))
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "http://backend.local/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"),
"backend.local:80 and backend.local are the same for http")
})
t.Run("rewrites when redirect has :443 but target omits it", func(t *testing.T) {
targetNoPort, _ := url.Parse("https://heise.de")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(targetNoPort, "", newReq("https://public.example.com/"))
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de:443/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de and heise.de:443 are the same for https")
})
t.Run("does not conflate non-default ports", func(t *testing.T) {
target8443, _ := url.Parse("https://backend.internal:8443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target8443, "", newReq("https://public.example.com/"))
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://backend.internal/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://backend.internal/path", resp.Header.Get("Location"),
"backend.internal:8443 != backend.internal (port 443), must not rewrite")
})
// --- Edge cases: encoded paths ---
t.Run("preserves percent-encoded path segments", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"),
"http://backend.internal:8080/path%20with%20spaces/file%2Fname")
require.NoError(t, err)
loc := resp.Header.Get("Location")
assert.Contains(t, loc, "public.example.com")
parsed, err := url.Parse(loc)
require.NoError(t, err)
assert.Equal(t, "/path with spaces/file/name", parsed.Path)
})
t.Run("preserves encoded query parameters with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/v1", newReq("https://public.example.com/v1/"),
"http://backend.internal:8080/redirect?url=http%3A%2F%2Fexample.com")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/v1/redirect?url=http%3A%2F%2Fexample.com", resp.Header.Get("Location"))
})
}
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
// Out is a shallow clone of In with headers copied.

View File

@@ -11,19 +11,21 @@ import (
)
type Mapping struct {
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
PassHostHeader bool
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
PassHostHeader bool
RewriteRedirects bool
}
type targetResult struct {
url *url.URL
matchedPath string
serviceID string
accountID types.AccountID
passHostHeader bool
url *url.URL
matchedPath string
serviceID string
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
@@ -56,11 +58,12 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bo
target := m.Paths[path]
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
return targetResult{
url: target,
matchedPath: path,
serviceID: m.ID,
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
url: target,
matchedPath: path,
serviceID: m.ID,
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
rewriteRedirects: m.RewriteRedirects,
}, true
}
}

View File

@@ -491,11 +491,12 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
paths[pathMapping.GetPath()] = targetURL
}
return proxy.Mapping{
ID: mapping.GetId(),
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
Paths: paths,
PassHostHeader: mapping.GetPassHostHeader(),
ID: mapping.GetId(),
AccountID: types.AccountID(mapping.GetAccountId()),
Host: mapping.GetDomain(),
Paths: paths,
PassHostHeader: mapping.GetPassHostHeader(),
RewriteRedirects: mapping.GetRewriteRedirects(),
}
}

View File

@@ -2865,6 +2865,9 @@ components:
pass_host_header:
type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
rewrite_redirects:
type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
auth:
$ref: '#/components/schemas/ReverseProxyAuthConfig'
meta:
@@ -2925,6 +2928,9 @@ components:
pass_host_header:
type: boolean
description: When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
rewrite_redirects:
type: boolean
description: When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
auth:
$ref: '#/components/schemas/ReverseProxyAuthConfig'
required:

View File

@@ -1992,6 +1992,9 @@ type ReverseProxy struct {
// ProxyCluster The proxy cluster handling this reverse proxy (derived from domain)
ProxyCluster *string `json:"proxy_cluster,omitempty"`
// RewriteRedirects When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
RewriteRedirects *bool `json:"rewrite_redirects,omitempty"`
// Targets List of target backends for this reverse proxy
Targets []ReverseProxyTarget `json:"targets"`
}
@@ -2065,6 +2068,9 @@ type ReverseProxyRequest struct {
// PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address
PassHostHeader *bool `json:"pass_host_header,omitempty"`
// RewriteRedirects When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain
RewriteRedirects *bool `json:"rewrite_redirects,omitempty"`
// Targets List of target backends for this reverse proxy
Targets []ReverseProxyTarget `json:"targets"`
}

View File

@@ -399,6 +399,9 @@ type ProxyMapping struct {
// When true, the original Host header from the client request is passed
// through to the backend instead of being rewritten to the backend's address.
PassHostHeader bool `protobuf:"varint,8,opt,name=pass_host_header,json=passHostHeader,proto3" json:"pass_host_header,omitempty"`
// When true, Location headers in backend responses are rewritten to replace
// the backend address with the public-facing domain.
RewriteRedirects bool `protobuf:"varint,9,opt,name=rewrite_redirects,json=rewriteRedirects,proto3" json:"rewrite_redirects,omitempty"`
}
func (x *ProxyMapping) Reset() {
@@ -489,6 +492,13 @@ func (x *ProxyMapping) GetPassHostHeader() bool {
return false
}
func (x *ProxyMapping) GetRewriteRedirects() bool {
if x != nil {
return x.RewriteRedirects
}
return false
}
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.
type SendAccessLogRequest struct {
state protoimpl.MessageState

View File

@@ -67,6 +67,9 @@ message ProxyMapping {
// When true, the original Host header from the client request is passed
// through to the backend instead of being rewritten to the backend's address.
bool pass_host_header = 8;
// When true, Location headers in backend responses are rewritten to replace
// the backend address with the public-facing domain.
bool rewrite_redirects = 9;
}
// SendAccessLogRequest consists of one or more AccessLogs from a Proxy.