feat(private-service): expose NetBird-only services over tunnel peers

Adds a new "private" service mode for the reverse proxy: services
reachable exclusively over the embedded WireGuard tunnel, gated by
per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call
  ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5):
  capability gate. Management never streams private mappings to
  proxies that don't claim the capability; the broadcast path applies
  the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer,
  checks the peer's groups against service.AccessGroups, and mints
  a session JWT on success. checkPeerGroupAccess fails closed when
  a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry
  peer_group_ids + peer_group_names so the proxy can authorise
  policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account
  inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird
  client and dial the target via the proxy host's network stack for
  upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON-
  serialised). Validate() rejects bearer auth on private services.
  Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware.
  Request handler routes private services through forwardWithTunnelPeer
  and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and
  injectPrivateServicePolicies (synthetic ACL) gate on
  len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in
  proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS
  on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache
  with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the
  account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the
  full ProxyMapping JSON from info logs (auth_token + header-auth
  hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups,
  group_names. sessionkey.Claims carries Email + Groups + GroupNames
  so the proxy can stamp identity onto upstream requests without an
  extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the
  proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the
  authenticated identity (strips client-supplied values first to
  prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's
  group memberships at write time so the dashboard can render group
  context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster
  gains private + supports_private. ReverseProxyTarget target_type
  enum gains "cluster". ServiceTargetOptions gains direct_upstream.
  ProxyAccessLog gains user_groups.
This commit is contained in:
mlsmaycon
2026-05-20 21:39:22 +02:00
parent 37052fd5bc
commit 167ee08e14
72 changed files with 6584 additions and 2586 deletions

View File

@@ -4,13 +4,16 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"sync"
"testing"
"time"
@@ -62,7 +65,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -79,7 +82,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -93,7 +96,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -108,7 +111,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -121,7 +124,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -137,8 +140,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -154,7 +157,7 @@ func TestRemoveDomain(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
mw.RemoveDomain("example.com")
@@ -178,7 +181,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -195,7 +198,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -216,7 +219,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -237,9 +240,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
@@ -262,15 +265,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
assert.Equal(t, "authenticated", rec.Body.String())
}
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
// JWT's groups claim into CapturedData so policy-aware middlewares can
// authorise without an extra management round-trip.
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
groups := []string{"engineering", "sre"}
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd, "captured data must be present in request context")
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
require.NoError(t, err)
var backendCalled bool
@@ -293,10 +329,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -320,10 +356,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -345,7 +381,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
scheme := &stubScheme{
@@ -357,7 +393,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -410,7 +446,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -427,7 +463,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
// First scheme (PIN) always fails, second scheme (password) succeeds.
@@ -446,7 +482,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -476,7 +512,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -500,7 +536,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
@@ -509,10 +545,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
require.Error(t, err)
// The original valid config should still be intact.
@@ -536,7 +572,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -563,7 +599,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -590,7 +626,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -678,7 +714,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -714,7 +750,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -755,7 +791,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -781,11 +817,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
return "", oidcURL, nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -809,11 +846,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -834,7 +872,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
@@ -852,7 +890,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
@@ -895,7 +933,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -915,7 +953,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -938,7 +976,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -955,7 +993,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -1006,7 +1044,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && accepted[ha.GetHeaderValue()] {
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
@@ -1015,7 +1053,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
// Single Header scheme (as if one entry existed), but the mock checks both values.
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -1059,3 +1097,173 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
assert.False(t, backendCalled, "unknown token should be rejected")
})
}
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
// scheme is configured and the request arrived without TLS, the middleware
// short-circuits with a 400 instead of dispatching to the IdP redirect.
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
}
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
// over TLS — the block only fires on plain HTTP.
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
}
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
// block only fires when an OIDC scheme is configured. PIN-only domains
// pass through normally on plain HTTP.
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
}
// TestProtect_SessionCookieOnPlainHTTP_LogsWarn verifies that a request
// carrying a valid session cookie over plain HTTP is still forwarded but
// emits a WARN-level log line for the operator.
func TestProtect_SessionCookieOnPlainHTTP_LogsWarn(t *testing.T) {
logger, hook := newTestLogger()
mw := NewMiddleware(logger, nil, nil)
kp := generateTestKeyPair(t)
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, backendCalled, "backend should still be reached — we don't drop the cookie")
assert.Equal(t, http.StatusOK, rec.Code)
var found bool
for _, entry := range hook.entries() {
if entry.Level == log.WarnLevel && strings.Contains(entry.Message, "session cookie on plain HTTP path") {
found = true
break
}
}
assert.True(t, found, "expected WARN log for session cookie on plain HTTP")
}
// TestProtect_SessionCookieOverTLS_NoWarn confirms the WARN only fires
// on plain HTTP — TLS requests with a session cookie behave as before.
func TestProtect_SessionCookieOverTLS_NoWarn(t *testing.T) {
logger, hook := newTestLogger()
mw := NewMiddleware(logger, nil, nil)
kp := generateTestKeyPair(t)
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
for _, entry := range hook.entries() {
assert.NotContains(t, entry.Message, "session cookie on plain HTTP path", "no plain-HTTP cookie warn expected over TLS")
}
}
// captureHook is a minimal logrus hook that records emitted entries for
// inspection in tests. It avoids pulling in the full sirupsen test
// helpers package (which the rest of the codebase doesn't use).
type captureHook struct {
mu sync.Mutex
records []log.Entry
}
func (h *captureHook) Levels() []log.Level { return log.AllLevels }
func (h *captureHook) Fire(entry *log.Entry) error {
h.mu.Lock()
defer h.mu.Unlock()
h.records = append(h.records, *entry)
return nil
}
func (h *captureHook) entries() []log.Entry {
h.mu.Lock()
defer h.mu.Unlock()
return append([]log.Entry{}, h.records...)
}
// newTestLogger builds an isolated logrus logger with a capture hook so
// tests can assert on emitted records without contending on the global
// logger.
func newTestLogger() (*log.Logger, *captureHook) {
hook := &captureHook{}
logger := log.New()
logger.SetOutput(io.Discard)
logger.AddHook(hook)
logger.SetLevel(log.DebugLevel)
return logger, hook
}