mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
[management,proxy] Add per-target options to reverse proxy (#5501)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -87,6 +88,188 @@ func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode PathRewriteMode
|
||||
wantErr string
|
||||
}{
|
||||
{"empty is default", "", ""},
|
||||
{"preserve is valid", PathRewritePreserve, ""},
|
||||
{"unknown rejected", "regex", "unknown path_rewrite mode"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.PathRewrite = tt.mode
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
wantErr string
|
||||
}{
|
||||
{"valid 30s", 30 * time.Second, ""},
|
||||
{"valid 2m", 2 * time.Minute, ""},
|
||||
{"zero is fine", 0, ""},
|
||||
{"negative", -1 * time.Second, "must be positive"},
|
||||
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.RequestTimeout = tt.timeout
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
|
||||
t.Run("valid headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"X-Custom": "value",
|
||||
"X-Trace": "abc123",
|
||||
}
|
||||
assert.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("CRLF in key", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
|
||||
})
|
||||
|
||||
t.Run("CRLF in value", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
|
||||
assert.ErrorContains(t, rp.Validate(), "invalid characters")
|
||||
})
|
||||
|
||||
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reserved header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Host header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
|
||||
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
|
||||
})
|
||||
|
||||
t.Run("too many headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
headers := make(map[string]string, 17)
|
||||
for i := range 17 {
|
||||
headers[fmt.Sprintf("X-H%d", i)] = "v"
|
||||
}
|
||||
rp.Targets[0].Options.CustomHeaders = headers
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
|
||||
})
|
||||
|
||||
t.Run("key too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
|
||||
assert.ErrorContains(t, rp.Validate(), "key")
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("value too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
|
||||
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"x-custom": "a",
|
||||
"X-Custom": "b",
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "collide")
|
||||
})
|
||||
}
|
||||
|
||||
func TestToProtoMapping_TargetOptions(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
Options: TargetOptions{
|
||||
SkipTLSVerify: true,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
PathRewrite: PathRewritePreserve,
|
||||
CustomHeaders: map[string]string{"X-Custom": "val"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
|
||||
opts := pm.Path[0].Options
|
||||
require.NotNil(t, opts, "options should be populated")
|
||||
assert.True(t, opts.SkipTlsVerify)
|
||||
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
|
||||
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
|
||||
require.NotNil(t, opts.RequestTimeout)
|
||||
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
|
||||
}
|
||||
|
||||
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
|
||||
}
|
||||
|
||||
func TestIsDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
|
||||
Reference in New Issue
Block a user