[management, client, proxy] Expose NetBird-only services over tunnel peers (#6226)

Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON-serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
This commit is contained in:
Maycon Santos
2026-05-25 17:41:50 +02:00
committed by GitHub
parent 0358be2313
commit 7aebdd69dd
84 changed files with 7810 additions and 933 deletions

View File

@@ -45,10 +45,14 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
}
}
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
// ValidateSessionJWT validates a session JWT and returns the user ID, the
// user's email (when carried), the authentication method, any embedded
// group memberships, and the parallel group display names. email,
// groups, and groupNames may be empty for tokens minted before those
// claims were introduced. groupNames pairs positionally with groups.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) {
if publicKey == nil {
return "", "", fmt.Errorf("no public key configured for domain")
return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain")
}
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
@@ -58,20 +62,46 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey)
return publicKey, nil
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
if err != nil {
return "", "", fmt.Errorf("parse token: %w", err)
return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return "", "", fmt.Errorf("invalid token claims")
return "", "", "", nil, nil, fmt.Errorf("invalid token claims")
}
sub, _ := claims.GetSubject()
if sub == "" {
return "", "", fmt.Errorf("missing subject claim")
return "", "", "", nil, nil, fmt.Errorf("missing subject claim")
}
methodClaim, _ := claims["method"].(string)
emailClaim, _ := claims["email"].(string)
groups = extractGroupsClaim(claims["groups"])
groupNames = extractGroupsClaim(claims["group_names"])
return sub, methodClaim, nil
return sub, emailClaim, methodClaim, groups, groupNames, nil
}
// extractGroupsClaim converts a decoded JWT array claim (used for both
// "groups" and "group_names") into a string slice. The JWT library
// surfaces JSON arrays as []interface{}, so each element is inspected
// individually; entries that are not strings, or that are empty strings,
// are dropped without error. The result is nil when the claim is absent,
// is not an array, or yields no usable entries.
func extractGroupsClaim(claim interface{}) []string {
	items, isList := claim.([]interface{})
	if !isList {
		return nil
	}

	var out []string
	for _, item := range items {
		str, isString := item.(string)
		if !isString || str == "" {
			continue
		}
		// Allocate lazily so a claim with no usable entries stays nil.
		if out == nil {
			out = make([]string, 0, len(items))
		}
		out = append(out, str)
	}
	return out
}

View File

@@ -63,6 +63,7 @@ var (
preSharedKey string
supportsCustomPorts bool
requireSubdomain bool
private bool
geoDataDir string
crowdsecAPIURL string
crowdsecAPIKey string
@@ -105,6 +106,8 @@ func init() {
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain")
rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Enable private services accessible with NetBird-Only authentication mode.")
_ = rootCmd.Flags().MarkHidden("private")
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
@@ -161,7 +164,8 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.Server{
srv := proxy.New(proxy.Config{
ListenAddr: addr,
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
@@ -178,7 +182,7 @@ func runServer(cmd *cobra.Command, args []string) error {
ACMEChallengeType: acmeChallengeType,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddress: healthAddr,
HealthAddr: healthAddr,
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
@@ -188,12 +192,13 @@ func runServer(cmd *cobra.Command, args []string) error {
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
RequireSubdomain: requireSubdomain,
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,
}
})
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()

547
proxy/inbound.go Normal file
View File

@@ -0,0 +1,547 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
stdlog "log"
"net"
"net/http"
"net/netip"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// Tunables for the per-account inbound listeners.
const (
	// httpInboundReadHeaderTimeout bounds how long a per-account
	// http.Server waits for request headers; it mirrors the host
	// listener's setting.
	httpInboundReadHeaderTimeout = 30 * time.Second

	// httpInboundIdleTimeout bounds idle keep-alive connections on the
	// per-account inbound HTTP servers; it mirrors the host listener.
	httpInboundIdleTimeout = 90 * time.Second

	// inboundShutdownTimeout is the budget a per-account http.Server gets
	// to drain in-flight requests during teardown.
	inboundShutdownTimeout = 5 * time.Second

	// privateInboundPortHTTPS is the TLS port on the WireGuard side. Every
	// account binds on its own embedded netstack, so a fixed port does
	// not collide across accounts.
	privateInboundPortHTTPS = 443

	// privateInboundPortHTTP is the plain-HTTP port on the WireGuard side.
	privateInboundPortHTTP = 80
)
// inboundManager wires per-account inbound listeners into the proxy
// pipeline when private mode (--private) is enabled. When private mode is
// off the manager is nil; its route, query and lifecycle-callback methods
// check for a nil receiver and return early.
type inboundManager struct {
	// logger receives all diagnostics emitted by the manager and by the
	// per-account HTTP servers it starts.
	logger *log.Logger
	// handler is the HTTP handler chain served on every per-account
	// listener.
	handler http.Handler
	// tlsConfig is installed on each per-account HTTPS server.
	tlsConfig *tls.Config
	// muxLock guards entries and pendingRoutes.
	muxLock sync.Mutex
	// entries holds the live listener bundle for each account.
	entries map[types.AccountID]*inboundEntry
	// pendingRoutes holds routes added for an account that has no live
	// entry yet; they are replayed by flushPending.
	pendingRoutes map[types.AccountID][]pendingInboundRoute
}
// inboundEntry owns the listeners, router and HTTP servers for a single
// account's embedded netstack.
type inboundEntry struct {
	// router dispatches connections accepted on either listener.
	router *nbtcp.Router
	// tlsListener is the netstack listener bound on privateInboundPortHTTPS.
	tlsListener net.Listener
	// plainListener is the netstack listener bound on privateInboundPortHTTP.
	plainListener net.Listener
	// httpsServer serves TLS traffic handed over by the router.
	httpsServer *http.Server
	// httpServer serves plain-HTTP traffic handed over by the router.
	httpServer *http.Server
	// cancel stops the run context shared by the entry's goroutines.
	cancel context.CancelFunc
	// wg tracks the goroutines started in bringUp so tearDown can wait
	// for them.
	wg sync.WaitGroup
}
// pendingInboundRoute holds a route that arrived before the account's
// listener finished starting.
type pendingInboundRoute struct {
	// host is the SNI/host key the route is registered under.
	host nbtcp.SNIHost
	// route is the route to replay once the account's router exists.
	route nbtcp.Route
}
// newInboundManager builds an inboundManager that serves the given HTTP
// handler chain with the given TLS config. Both bookkeeping maps start
// empty and non-nil so they can be written to immediately.
func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager {
	mgr := new(inboundManager)
	mgr.logger = logger
	mgr.handler = handler
	mgr.tlsConfig = tlsConfig
	mgr.entries = map[types.AccountID]*inboundEntry{}
	mgr.pendingRoutes = map[types.AccountID][]pendingInboundRoute{}
	return mgr
}
// onClientReady is the "client ready" lifecycle callback (registered via
// NetBird.SetClientLifecycle). It starts the account's listener pair,
// replays any routes queued while the listener was down, and returns the
// resulting entry as an opaque value that is later handed back to
// onClientStop. It returns a literal nil when the manager is nil or the
// listeners could not be started; a start failure is logged at warn level
// and is not fatal.
func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any {
	if m == nil {
		return nil
	}

	started, startErr := m.bringUp(ctx, accountID, client)
	if startErr != nil {
		m.logger.WithField("account_id", accountID).WithError(startErr).Warn("failed to start per-account inbound listener; continuing without inbound")
		return nil
	}

	m.flushPending(accountID, started)

	httpsAddr := started.tlsListener.Addr().String()
	httpAddr := started.plainListener.Addr().String()
	m.logger.WithFields(log.Fields{
		"account_id": accountID,
		"https":      httpsAddr,
		"http":       httpAddr,
	}).Info("per-account inbound listeners up")

	return started
}
// onClientStop is the "client stopped" lifecycle callback. state is the
// opaque value onClientReady returned; anything that is not a non-nil
// *inboundEntry is ignored, as is a nil manager.
func (m *inboundManager) onClientStop(accountID types.AccountID, state any) {
	if m == nil {
		return
	}
	if bundle, isEntry := state.(*inboundEntry); isEntry && bundle != nil {
		m.tearDown(accountID, bundle)
	}
}
// bringUp opens both listeners on the account's netstack, builds the
// router, and starts the parallel HTTP servers.
//
// It returns an error only when one of the two netstack listeners cannot
// be opened; in that case nothing is left running. On success four
// goroutines are running (router, HTTPS server, HTTP server, plain-listener
// feeder), all tracked by entry.wg and tied to a context derived from ctx,
// and the entry is registered in m.entries.
func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) {
	tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS))
	if err != nil {
		return nil, fmt.Errorf("listen tls on netstack: %w", err)
	}
	plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP))
	if err != nil {
		// Do not leak the TLS listener when the plain one fails.
		_ = tlsListener.Close()
		return nil, fmt.Errorf("listen plain on netstack: %w", err)
	}
	router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr()))
	// Every request served here carries the account's peer lookup in its
	// context.
	scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client))
	// markOverlayOrigin stamps every connection accepted by an inbound
	// listener with a context value middlewares can read to skip
	// geo/CrowdSec checks (the source address is always inside the
	// NetBird CGNAT range and won't match either dataset).
	markOverlayOrigin := func(ctx context.Context, _ net.Conn) context.Context {
		return types.WithOverlayOrigin(ctx)
	}
	httpsServer := &http.Server{
		Handler:           scopedHandler,
		TLSConfig:         m.tlsConfig,
		ReadHeaderTimeout: httpInboundReadHeaderTimeout,
		IdleTimeout:       httpInboundIdleTimeout,
		ErrorLog:          newInboundErrorLog(m.logger, "https", accountID),
		ConnContext:       markOverlayOrigin,
	}
	httpServer := &http.Server{
		Handler:           scopedHandler,
		ReadHeaderTimeout: httpInboundReadHeaderTimeout,
		IdleTimeout:       httpInboundIdleTimeout,
		ErrorLog:          newInboundErrorLog(m.logger, "http", accountID),
		ConnContext:       markOverlayOrigin,
	}
	// runCtx is cancelled by tearDown (via entry.cancel) or by the parent ctx.
	runCtx, cancel := context.WithCancel(ctx)
	entry := &inboundEntry{
		router:        router,
		tlsListener:   tlsListener,
		plainListener: plainListener,
		httpsServer:   httpsServer,
		httpServer:    httpServer,
		cancel:        cancel,
	}
	// The router accepts directly on the TLS listener.
	entry.wg.Add(1)
	go func() {
		defer entry.wg.Done()
		if err := router.Serve(runCtx, tlsListener); err != nil {
			m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err)
		}
	}()
	// The HTTPS server consumes connections the router hands to its HTTP
	// listener. The empty cert/key paths mean certificates must come from
	// TLSConfig.
	entry.wg.Add(1)
	go func() {
		defer entry.wg.Done()
		if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
			m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err)
		}
	}()
	// The plain HTTP server consumes the router's plain-HTTP listener.
	entry.wg.Add(1)
	go func() {
		defer entry.wg.Done()
		if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) {
			m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err)
		}
	}()
	// Connections on the plain listener are accepted here and pushed into
	// the router.
	entry.wg.Add(1)
	go func() {
		defer entry.wg.Done()
		feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID)
	}()
	// Publish the entry last so AddRoute/RemoveRoute only see a fully
	// constructed bundle.
	m.muxLock.Lock()
	m.entries[accountID] = entry
	m.muxLock.Unlock()
	return entry, nil
}
// tearDown stops everything bringUp started for the entry. It first
// unregisters the entry (and drops the account's queued routes) if it is
// still the one on record, cancels the run context, gives both HTTP
// servers a shared inboundShutdownTimeout budget to drain, closes the two
// netstack listeners, and finally waits for the entry's goroutines.
// Errors along the way are logged at debug level and do not abort the
// remaining steps.
func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) {
	func() {
		m.muxLock.Lock()
		defer m.muxLock.Unlock()
		if m.entries[accountID] != entry {
			return
		}
		delete(m.entries, accountID)
		delete(m.pendingRoutes, accountID)
	}()

	entry.cancel()

	drainCtx, release := context.WithTimeout(context.Background(), inboundShutdownTimeout)
	defer release()

	// The steps run strictly in this order; each failure is only logged.
	steps := []struct {
		logFormat string
		run       func() error
	}{
		{"per-account https shutdown: %v", func() error { return entry.httpsServer.Shutdown(drainCtx) }},
		{"per-account http shutdown: %v", func() error { return entry.httpServer.Shutdown(drainCtx) }},
		{"close per-account tls listener: %v", func() error { return entry.tlsListener.Close() }},
		{"close per-account plain listener: %v", func() error { return entry.plainListener.Close() }},
	}
	for _, step := range steps {
		if err := step.run(); err != nil {
			m.logger.Debugf(step.logFormat, err)
		}
	}

	entry.wg.Wait()
}
// AddRoute registers an SNI/host route for the account. When the account
// has a live listener the route goes to its router (the router call is
// made after muxLock is released); otherwise the route is queued and
// replayed once the listener comes up. A nil manager is a no-op.
func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
	if m == nil {
		return
	}

	var target *nbtcp.Router
	m.muxLock.Lock()
	entry, live := m.entries[accountID]
	if live {
		target = entry.router
	} else {
		m.queuePendingLocked(accountID, host, route)
	}
	m.muxLock.Unlock()

	if !live {
		return
	}
	target.AddRoute(host, route)
}
// RemoveRoute removes a route for the account. Any queued copy matching
// host/svcID is always pruned; if the account also has a live listener the
// route is removed from its router (outside muxLock). Calling it before the
// listener is up, or on a nil manager, is safe.
func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
	if m == nil {
		return
	}

	var target *nbtcp.Router
	m.muxLock.Lock()
	m.dropPendingLocked(accountID, host, svcID)
	entry, live := m.entries[accountID]
	if live {
		target = entry.router
	}
	m.muxLock.Unlock()

	if !live {
		return
	}
	target.RemoveRoute(host, svcID)
}
// queuePendingLocked upserts a pending route for the account: an existing
// queued route with the same host and service ID is replaced in place,
// otherwise the route is appended. The caller must hold muxLock.
func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
	fresh := pendingInboundRoute{host: host, route: route}
	backlog := m.pendingRoutes[accountID]

	for idx := range backlog {
		if backlog[idx].host != host || backlog[idx].route.ServiceID != route.ServiceID {
			continue
		}
		backlog[idx] = fresh
		m.pendingRoutes[accountID] = backlog
		return
	}

	m.pendingRoutes[accountID] = append(backlog, fresh)
}
// dropPendingLocked deletes every queued route for the account whose host
// and service ID both match. The queue is compacted in place, and the
// account's map key is removed once nothing is left. The caller must hold
// muxLock.
func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
	backlog, present := m.pendingRoutes[accountID]
	if !present {
		return
	}

	kept := 0
	for _, candidate := range backlog {
		if candidate.host == host && candidate.route.ServiceID == svcID {
			continue
		}
		backlog[kept] = candidate
		kept++
	}

	if kept == 0 {
		delete(m.pendingRoutes, accountID)
		return
	}
	m.pendingRoutes[accountID] = backlog[:kept]
}
// flushPending takes the account's queued routes out of the pending map
// under muxLock and then replays them, in queue order, onto the entry's
// router with the lock released.
//
// NOTE(review): bringUp publishes the entry before this runs, so an
// AddRoute landing in between goes straight to the router and may then be
// followed by an older queued copy for the same host — confirm whether
// Router.AddRoute ordering matters here.
func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) {
	drained := func() []pendingInboundRoute {
		m.muxLock.Lock()
		defer m.muxLock.Unlock()
		backlog := m.pendingRoutes[accountID]
		delete(m.pendingRoutes, accountID)
		return backlog
	}()

	for idx := range drained {
		entry.router.AddRoute(drained[idx].host, drained[idx].route)
	}
}
// HasInbound reports whether an entry is registered for the account. A nil
// manager reports false. Used by tests.
func (m *inboundManager) HasInbound(accountID types.AccountID) bool {
	if m == nil {
		return false
	}

	m.muxLock.Lock()
	_, registered := m.entries[accountID]
	m.muxLock.Unlock()

	return registered
}
// PendingRouteCount returns how many routes are currently queued for the
// account. A nil manager reports zero. Used by tests.
func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int {
	if m == nil {
		return 0
	}

	m.muxLock.Lock()
	queued := len(m.pendingRoutes[accountID])
	m.muxLock.Unlock()

	return queued
}
// InboundListenerInfo describes the bound addresses of a single
// per-account inbound listener. Both addresses live on the embedded
// netstack of the account's WireGuard client and share the same tunnel IP.
type InboundListenerInfo struct {
	// TunnelIP is the host part of the listener address; it may be empty
	// when the listener address could not be parsed.
	TunnelIP string
	// HTTPSPort is the TLS listener port.
	HTTPSPort uint16
	// HTTPPort is the plain-HTTP listener port.
	HTTPPort uint16
}
// ListenerInfo looks up the account's inbound listener addresses. The
// boolean is false, and the info is the zero value, when the manager is
// nil or the account has no registered (non-nil) entry. It exists so
// listener reachability can be surfaced to operators.
func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) {
	if m == nil {
		return InboundListenerInfo{}, false
	}

	m.muxLock.Lock()
	defer m.muxLock.Unlock()

	if live, found := m.entries[accountID]; found && live != nil {
		return listenerInfoFromEntry(live), true
	}
	return InboundListenerInfo{}, false
}
// Snapshot returns the listener info of every account with a registered
// entry at call time. The result is nil when the manager is nil (private
// mode off) or no entry is registered yet; nil entries are skipped.
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
	if m == nil {
		return nil
	}

	m.muxLock.Lock()
	defer m.muxLock.Unlock()

	total := len(m.entries)
	if total == 0 {
		return nil
	}

	snapshot := make(map[types.AccountID]InboundListenerInfo, total)
	for accountID, live := range m.entries {
		if live != nil {
			snapshot[accountID] = listenerInfoFromEntry(live)
		}
	}
	return snapshot
}
// listenerInfoFromEntry derives the tunnel IP and ports from an entry's
// listeners. Ports default to the fixed private inbound ports and are
// overridden by any non-zero port parsed from a listener address. The
// tunnel IP comes from the TLS listener; the plain listener's host is used
// only when the TLS one yielded nothing.
func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo {
	var tunnelIP string
	httpsPort := uint16(privateInboundPortHTTPS)
	httpPort := uint16(privateInboundPortHTTP)

	if ln := entry.tlsListener; ln != nil {
		host, port := splitHostPort(ln.Addr())
		tunnelIP = host
		if port != 0 {
			httpsPort = port
		}
	}

	if ln := entry.plainListener; ln != nil {
		host, port := splitHostPort(ln.Addr())
		if tunnelIP == "" {
			tunnelIP = host
		}
		if port != 0 {
			httpPort = port
		}
	}

	return InboundListenerInfo{
		TunnelIP:  tunnelIP,
		HTTPSPort: httpsPort,
		HTTPPort:  httpPort,
	}
}
// splitHostPort breaks a net.Addr into host and numeric port. A nil or
// unsplittable address yields ("", 0); an address whose port is empty, not
// a number, or outside the uint16 range yields the host with port 0.
func splitHostPort(addr net.Addr) (string, uint16) {
	if addr == nil {
		return "", 0
	}

	host, rawPort, splitErr := net.SplitHostPort(addr.String())
	switch {
	case splitErr != nil:
		return "", 0
	case rawPort == "":
		return host, 0
	}

	if parsed, parseErr := strconv.ParseUint(rawPort, 10, 16); parseErr == nil {
		return host, uint16(parsed)
	}
	return host, 0
}
// feedRouterFromListener accepts on the plain-HTTP netstack listener and
// hands every connection to the account's router. Per the router's
// contract it peeks the first byte and dispatches to the plain-HTTP
// channel for non-TLS streams or the TLS channel for ClientHellos that
// arrive on :80.
//
// The function blocks until ctx is cancelled or the listener is closed. A
// helper goroutine closes ln when ctx is done so a blocked Accept is
// released. Any other Accept error is logged at debug level and retried
// after a capped exponential backoff (as net/http's Server.Serve does), so
// a persistent failure cannot turn the loop into a busy spin.
func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) {
	go func() {
		<-ctx.Done()
		_ = ln.Close()
	}()

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)

			// Double the delay on consecutive failures, up to the cap.
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else if backoff *= 2; backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}

		// A successful accept resets the backoff.
		backoff = 0
		router.HandleConn(ctx, conn)
	}
}
// accountDialResolver builds a DialResolver that always answers with the
// given embedded client's DialContext. Both the accountID passed here and
// the one passed at resolve time are ignored, because the router it is
// installed on serves a single account.
func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver {
	var resolve nbtcp.DialResolver = func(_ types.AccountID) (types.DialContextFunc, error) {
		return client.DialContext, nil
	}
	return resolve
}
// accountTunnelLookup builds a TunnelLookupFunc that resolves a tunnel IP
// to a peer identity through the embedded client's IdentityForIP. It
// returns nil when client is nil. The lookup reports false for IPs the
// client does not know, which lets callers reject such sources without a
// ValidateTunnelPeer round-trip.
func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
	if client == nil {
		return nil
	}

	return func(ip netip.Addr) (auth.PeerIdentity, bool) {
		pubKey, fqdn, known := client.IdentityForIP(ip)
		if known {
			return auth.PeerIdentity{PubKey: pubKey, TunnelIP: ip, FQDN: fqdn}, true
		}
		return auth.PeerIdentity{}, false
	}
}
// withTunnelLookup wraps next so that every request's context carries the
// per-account peer lookup. With a nil lookup it returns next unchanged, so
// requests that do not arrive on a per-account listener (or any request
// when private mode is off) see exactly the handler they saw before.
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
	if lookup != nil {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			scoped := r.WithContext(auth.WithTunnelLookup(r.Context(), lookup))
			next.ServeHTTP(w, scoped)
		})
	}
	return next
}
// inboundDebugAdapter exposes an *inboundManager through the
// debug.InboundProvider interface, so the debug HTTP handler can show
// per-account inbound listener state without importing the proxy package.
type inboundDebugAdapter struct {
	// mgr is the wrapped manager; it may be nil.
	mgr *inboundManager
}

// InboundListeners converts the manager's current snapshot into the debug
// package's representation. It returns nil when there is no manager or no
// live listener.
func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo {
	if a.mgr == nil {
		return nil
	}

	live := a.mgr.Snapshot()
	if len(live) == 0 {
		return nil
	}

	converted := make(map[types.AccountID]debug.InboundListenerInfo, len(live))
	for accountID, src := range live {
		var dst debug.InboundListenerInfo
		dst.TunnelIP = src.TunnelIP
		dst.HTTPSPort = src.HTTPSPort
		dst.HTTPPort = src.HTTPPort
		converted[accountID] = dst
	}
	return converted
}
// newInboundErrorLog builds the stdlib logger handed to a per-account
// http.Server as ErrorLog. Everything written to it is emitted through
// logrus at warn level, tagged with the listener scheme and account ID.
//
// NOTE(review): logrus' WriterLevel returns a pipe writer that is expected
// to be closed when no longer needed; nothing visible here closes it on
// teardown — confirm whether that leaks per-server resources.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
	tags := log.Fields{
		"inbound-http": scheme,
		"account_id":   accountID,
	}
	sink := logger.WithFields(tags).WriterLevel(log.WarnLevel)
	return stdlog.New(sink, "", 0)
}

502
proxy/inbound_test.go Normal file
View File

@@ -0,0 +1,502 @@
package proxy
import (
"bufio"
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// bufioReader wraps the connection in a buffered reader so http.ReadResponse
// can parse the response line + headers off the wire. Each call creates a
// new reader with its own buffer.
func bufioReader(conn net.Conn) *bufio.Reader {
	return bufio.NewReader(conn)
}
// quietLogger returns a fresh logrus logger whose level is set to
// PanicLevel, so ordinary log calls in the code under test produce no
// output and test logs stay tidy.
func quietLogger() *log.Logger {
	logger := log.New()
	logger.SetLevel(log.PanicLevel)
	return logger
}
// TestInboundManager_RouteScopedToAccount checks that routes added before
// any listener exists are queued per account, and that removing them
// empties each account's queue independently.
func TestInboundManager_RouteScopedToAccount(t *testing.T) {
	mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)

	fixtures := []struct {
		account   types.AccountID
		host      nbtcp.SNIHost
		service   types.ServiceID
		queuedMsg string
	}{
		{"acct-a", "shared.example", "svc-a", "account A should have one queued route"},
		{"acct-b", "other.example", "svc-b", "account B should have one queued route"},
	}

	for _, fx := range fixtures {
		mgr.AddRoute(fx.account, fx.host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: fx.account, ServiceID: fx.service, Domain: string(fx.host)})
	}
	for _, fx := range fixtures {
		require.Equal(t, 1, mgr.PendingRouteCount(fx.account), fx.queuedMsg)
	}
	for _, fx := range fixtures {
		mgr.RemoveRoute(fx.account, fx.host, fx.service)
	}
	for _, fx := range fixtures {
		assert.Equal(t, 0, mgr.PendingRouteCount(fx.account), "queue should drain on remove")
	}
}
// TestInboundManager_PendingThenFlush checks the queue-then-replay path: a
// route added before the account's listener exists is queued, and
// flushPending both empties the queue and registers the route on the
// account's router.
func TestInboundManager_PendingThenFlush(t *testing.T) {
	mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
	accountID := types.AccountID("acct-1")
	host := nbtcp.SNIHost("example.test")
	route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"}

	mgr.AddRoute(accountID, host, route)
	require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up")

	// Simulate listener up by registering a fake entry, then flushing.
	router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
	require.True(t, router.IsEmpty(), "router should start without routes")
	entry := &inboundEntry{router: router}
	mgr.muxLock.Lock()
	mgr.entries[accountID] = entry
	mgr.muxLock.Unlock()

	mgr.flushPending(accountID, entry)

	assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush")
	// An empty queue alone does not prove the replay happened; the route
	// must now be visible on the router.
	assert.False(t, router.IsEmpty(), "flushed route should be registered on the router")
}
// fakeAddr is a minimal net.Addr for tests that never bind a real socket.
type fakeAddr struct {
	// addr is returned verbatim by String.
	addr string
}

// Compile-time check that *fakeAddr satisfies net.Addr.
var _ net.Addr = (*fakeAddr)(nil)

// Network always reports "tcp".
func (a *fakeAddr) Network() string {
	return "tcp"
}

// String returns the configured address text.
func (a *fakeAddr) String() string {
	return a.addr
}
// fakeMgmtClient implements roundtrip.managementClient for tests.
type fakeMgmtClient struct{}

// CreateProxyPeer ignores its arguments and always reports success, so
// tests never need a real management server.
func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
	return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// TestServer_PrivateInbound_NotEnabled_NoManager checks that
// initPrivateInbound leaves the server's inbound manager nil when the
// server is configured with Private=false (--private off).
func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) {
	srv := &Server{
		Logger:  quietLogger(),
		Private: false,
	}
	notFound := http.NotFoundHandler()

	srv.initPrivateInbound(notFound, nil)

	assert.Nil(t, srv.inbound, "manager should remain nil when --private is off")
}
// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that
// --private alone makes initPrivateInbound create the inbound manager and
// populate it with the handler and TLS config.
//
// NOTE(review): the assertions below only inspect the manager's fields;
// they do not observe whether lifecycle callbacks were registered on the
// NetBird transport — confirm whether that needs separate coverage.
func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
	s := &Server{Logger: quietLogger(), Private: true}
	// Construct a NetBird transport. We can't actually start the embedded
	// client here (that needs a real management server), so the transport
	// is only created, never started.
	s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
		MgmtAddr: "http://invalid.test",
	}, quietLogger(), nil, fakeMgmtClient{})
	s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec
	require.NotNil(t, s.inbound, "manager should be set when --private is on")
	assert.NotNil(t, s.inbound.handler, "handler should be set on manager")
	assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager")
}
// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that
// when the account already has a registered entry, AddRoute writes
// straight to that entry's router and leaves nothing in the pending queue.
func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
	mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
	accountID := types.AccountID("acct-1")

	router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
	mgr.muxLock.Lock()
	mgr.entries[accountID] = &inboundEntry{router: router}
	mgr.muxLock.Unlock()

	host := nbtcp.SNIHost("ready.example")
	mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)})

	assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up")
	// The empty queue shows the route was not parked; this shows it
	// actually reached the router.
	assert.False(t, router.IsEmpty(), "route should be registered on the live router")
}
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
// bit reported upstream tracks --private exclusively. The previous
// --private-inbound flag has been folded into --private.
//
// NOTE(review): as written the assertion reads back the same Private field
// the test just assigned, so it cannot fail. Presumably it should inspect
// whatever the server reports upstream as its capability — confirm and
// point the assertion at that value.
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
	tests := []struct {
		name     string
		private  bool
		expected bool
	}{
		{"off", false, false},
		{"on", true, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &Server{Private: tt.private}
			assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private")
		})
	}
}
// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA checks account
// isolation: a route added for account B must land only on B's router and
// leave account A's router untouched. Each account gets its own real
// router instance, installed directly into the manager's entry table.
func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) {
	manager := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)

	acctA, acctB := types.AccountID("acct-a"), types.AccountID("acct-b")
	rtrA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
	rtrB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})

	manager.muxLock.Lock()
	manager.entries[acctA] = &inboundEntry{router: rtrA}
	manager.entries[acctB] = &inboundEntry{router: rtrB}
	manager.muxLock.Unlock()

	sni := nbtcp.SNIHost("shared.example")
	routeForB := nbtcp.Route{
		Type:      nbtcp.RouteHTTP,
		AccountID: acctB,
		ServiceID: "svc-b",
		Domain:    string(sni),
	}
	manager.AddRoute(acctB, sni, routeForB)

	// IsEmpty reports true only when a router has no routes and no fallback,
	// so it distinguishes "received the mapping" from "never saw it".
	assert.True(t, rtrA.IsEmpty(), "account A router must not see account B's mappings")
	assert.False(t, rtrB.IsEmpty(), "account B router should hold its own mapping")
}
// TestInboundEntry_ShutdownIdempotent is a placeholder for the guarantee
// that tearing an inbound entry down twice does not panic — callers may
// invoke teardown from RemovePeer + StopAll. The body skips
// unconditionally, so nothing is asserted in this file; the skip message
// defers coverage to integration tests.
func TestInboundEntry_ShutdownIdempotent(t *testing.T) {
	t.Skip("teardown requires real netstack listeners; covered by integration tests")
}
// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account
// router pipeline against a loopback listener (proxy of a netstack
// listener for test purposes): a plain HTTP request lands on the plain
// http.Server and the inner handler observes a nil r.TLS, which is what
// auth.ResolveProto translates to "http" in the real pipeline.
//
// The client connection carries a deadline so that a router that never
// dispatches the connection fails the test quickly instead of blocking
// until the global test timeout.
func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) {
	logger := quietLogger()

	// captured records which scheme the handler inferred from r.TLS.
	var captured atomic.Value
	captured.Store("")
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.TLS == nil {
			captured.Store("http")
		} else {
			captured.Store("https")
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	})

	hostListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "loopback listener bind must succeed")
	defer hostListener.Close()

	router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr()))
	httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
	defer func() { _ = httpServer.Close() }()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
	go func() { _ = router.Serve(ctx, hostListener) }()

	conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second)
	require.NoError(t, err, "plain HTTP dial must succeed")
	defer conn.Close()
	// Bound the write and the blocking response read below.
	require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)), "connection deadline must be settable")

	_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
	require.NoError(t, err, "write must succeed")

	resp, err := http.ReadResponse(bufioReader(conn), nil)
	require.NoError(t, err, "must read response")
	defer resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path")
}
// TestWithTunnelLookup_AttachesLookupToContext checks that the per-account
// handler wrapper places the peerstore lookup function on the request
// context, and that the function retrieved downstream still resolves the
// identity it was built with. Phase 3's local-first deny path depends on
// this.
func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) {
	want := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"}
	resolver := auth.TunnelLookupFunc(func(netip.Addr) (auth.PeerIdentity, bool) {
		return want, true
	})

	// The innermost handler captures whatever lookup the wrapper injected.
	var fromCtx auth.TunnelLookupFunc
	capture := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		fromCtx = auth.TunnelLookupFromContext(req.Context())
	})

	req := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
	withTunnelLookup(capture, resolver).ServeHTTP(httptest.NewRecorder(), req)

	require.NotNil(t, fromCtx, "wrapper must inject the lookup into the request context")

	identity, found := fromCtx(netip.MustParseAddr("100.64.0.10"))
	assert.True(t, found, "lookup must round-trip through context")
	assert.Equal(t, want.FQDN, identity.FQDN, "lookup must return the same identity it was constructed with")
}
// TestWithTunnelLookup_NilLookupIsNoop checks that wrapping with a nil
// lookup leaves the request context without a lookup function while still
// calling the wrapped handler. Required for the host-level listener path
// to keep its byte-for-byte previous behaviour.
func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) {
	invoked := false
	next := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		invoked = true
		assert.Nil(t, auth.TunnelLookupFromContext(req.Context()), "host-level path must not see a lookup function")
	})

	req := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
	withTunnelLookup(next, nil).ServeHTTP(httptest.NewRecorder(), req)

	assert.True(t, invoked, "wrapper without lookup must still invoke next")
}
// fakeListener is a socket-less net.Listener stand-in for snapshot tests:
// it only reports a fixed address and never yields a connection, so no
// real socket has to be bound on the netstack.
type fakeListener struct {
	addr net.Addr
}

// Compile-time proof that the stub keeps satisfying net.Listener.
var _ net.Listener = (*fakeListener)(nil)

// Accept never produces a connection; it always reports net.ErrClosed.
func (l *fakeListener) Accept() (net.Conn, error) {
	return nil, net.ErrClosed
}

// Close has nothing to release and always succeeds.
func (l *fakeListener) Close() error {
	return nil
}

// Addr returns the address the stub was constructed with.
func (l *fakeListener) Addr() net.Addr {
	return l.addr
}
// TestInboundManager_ListenerInfo checks the observability accessors for a
// live entry: ListenerInfo must report the tunnel IP and both ports taken
// from the entry's listener addresses, Snapshot must contain the same
// record, and an unknown account must report ok=false.
func TestInboundManager_ListenerInfo(t *testing.T) {
	manager := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
	account := types.AccountID("acct-info")

	// Stub listeners: only their addresses matter for this test.
	httpsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS}
	httpAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP}
	entry := &inboundEntry{
		tlsListener:   &fakeListener{addr: httpsAddr},
		plainListener: &fakeListener{addr: httpAddr},
	}

	manager.muxLock.Lock()
	manager.entries[account] = entry
	manager.muxLock.Unlock()

	got, found := manager.ListenerInfo(account)
	require.True(t, found, "ListenerInfo must report ok for live entry")
	assert.Equal(t, "100.64.0.5", got.TunnelIP, "tunnel IP must come from listener address")
	assert.Equal(t, uint16(privateInboundPortHTTPS), got.HTTPSPort, "TLS port must match bound port")
	assert.Equal(t, uint16(privateInboundPortHTTP), got.HTTPPort, "HTTP port must match bound port")

	all := manager.Snapshot()
	require.Len(t, all, 1, "snapshot must contain exactly one entry")
	assert.Equal(t, got, all[account], "snapshot entry must equal direct lookup")

	_, found = manager.ListenerInfo(types.AccountID("missing"))
	assert.False(t, found, "ListenerInfo must report ok=false for unknown accounts")
}
// TestInboundManager_NilManagerSafe ensures the observability accessors
// are safe to call on a nil *inboundManager, which is the state when
// --private is off (the former --private-inbound flag was folded into
// --private). ListenerInfo must report ok=false and Snapshot must return
// nil rather than panicking.
func TestInboundManager_NilManagerSafe(t *testing.T) {
	var mgr *inboundManager
	_, ok := mgr.ListenerInfo("anything")
	assert.False(t, ok, "nil manager must return ok=false")
	assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot")
}
// TestInboundManager_ConcurrentAddRemove hammers AddRoute / RemoveRoute
// for one account and one host from many goroutines at once. The test
// makes no assertion about the final state; it fails only when the workers
// do not all finish within ten seconds, and relies on the race detector to
// surface unsynchronised access.
func TestInboundManager_ConcurrentAddRemove(t *testing.T) {
	const (
		workers    = 32
		iterations = 50
	)

	mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
	accountID := types.AccountID("acct-1")

	// churn is the per-worker body: repeatedly add then remove the same route.
	churn := func() {
		host := nbtcp.SNIHost("example.test")
		svc := types.ServiceID("svc")
		route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"}
		for n := 0; n < iterations; n++ {
			mgr.AddRoute(accountID, host, route)
			mgr.RemoveRoute(accountID, host, svc)
		}
	}

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			churn()
		}()
	}

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		wg.Wait()
	}()

	select {
	case <-finished:
	case <-time.After(10 * time.Second):
		t.Fatal("concurrent add/remove timed out")
	}
}
// TestFeedRouterFromListener_DeliversConnectionToHandler validates the
// per-account inbound chain end-to-end with a loopback listener
// substituted for the embedded netstack: a TCP connection arriving at
// the plain listener flows through feedRouterFromListener, the router's
// peek-and-dispatch, the wrapped HTTP server, and reaches the user
// handler. If the embedded netstack is delivering connections at all,
// this is the path they take. Failures localise to wiring bugs in the
// proxy, not the netstack.
//
// The client connection carries a deadline: without it a broken chain
// blocks forever in http.ReadResponse and the "handler was not invoked"
// branch below can never fire.
func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) {
	logger := quietLogger()

	// Buffered so the handler never blocks on the test goroutine.
	hits := make(chan string, 1)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits <- r.Host
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("served"))
	})

	plainLn, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "plain loopback bind must succeed")
	t.Cleanup(func() { _ = plainLn.Close() })

	router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr()))
	httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
	t.Cleanup(func() { _ = httpServer.Close() })

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
	go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1"))

	conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second)
	require.NoError(t, err, "must connect to the plain listener")
	t.Cleanup(func() { _ = conn.Close() })
	// Bound the write and the blocking response read below.
	require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)), "connection deadline must be settable")

	_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n"))
	require.NoError(t, err, "request write must succeed")

	resp, err := http.ReadResponse(bufioReader(conn), nil)
	require.NoError(t, err, "must read response from server")
	t.Cleanup(func() { _ = resp.Body.Close() })
	assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached")

	select {
	case host := <-hits:
		assert.Equal(t, "app.example", host, "handler must observe the request Host")
	case <-time.After(2 * time.Second):
		t.Fatal("handler was not invoked — connection did not flow through router → http server")
	}
}
// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a
// TLS ClientHello arriving on the plain listener is detected by the
// router peek and re-dispatched to the TLS channel — the cross-channel
// fallback the inbound stack relies on for HTTPS-on-:80 testing.
//
// Both the TLS handshake and the request/response exchange are bounded
// by timeouts, so a broken peek/dispatch path fails the test promptly
// instead of hanging in the handshake or in http.ReadResponse.
func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) {
	logger := quietLogger()

	// Buffered so the handler never blocks on the test goroutine.
	hits := make(chan string, 1)
	tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits <- r.Host
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("served-tls"))
	})

	plainLn, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "plain loopback bind must succeed")
	t.Cleanup(func() { _ = plainLn.Close() })
	tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "tls loopback bind must succeed")
	t.Cleanup(func() { _ = tlsLn.Close() })

	router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr()))
	tlsConfig := selfSignedTLSConfig(t)
	httpsServer := &http.Server{
		Handler:           tlsHandler,
		TLSConfig:         tlsConfig,
		ReadHeaderTimeout: time.Second,
	}
	t.Cleanup(func() { _ = httpsServer.Close() })

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }()
	go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls"))

	// The dialer timeout covers the TCP connect and the TLS handshake as a
	// whole, so a ClientHello that is never dispatched cannot hang the dial.
	dialer := &net.Dialer{Timeout: 2 * time.Second}
	tlsConn, err := tls.DialWithDialer(dialer, "tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec
	require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)")
	t.Cleanup(func() { _ = tlsConn.Close() })
	// Bound the request write and the blocking response read below.
	require.NoError(t, tlsConn.SetDeadline(time.Now().Add(5*time.Second)), "connection deadline must be settable")

	req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil)
	require.NoError(t, err)
	require.NoError(t, req.Write(tlsConn), "TLS request write must succeed")

	resp, err := http.ReadResponse(bufioReader(tlsConn), req)
	require.NoError(t, err, "must read TLS response")
	t.Cleanup(func() { _ = resp.Body.Close() })
	assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached")

	select {
	case host := <-hits:
		assert.Equal(t, "app.example", host, "TLS handler must observe the request Host")
	case <-time.After(2 * time.Second):
		t.Fatal("TLS handler was not invoked — peek/dispatch path is broken")
	}
}
// selfSignedTLSConfig returns a server-side TLS configuration backed by
// the static test key pair (testCertPEM / testKeyPEM) with TLS 1.2 as the
// minimum version. The test fails immediately if the pair cannot be
// parsed.
func selfSignedTLSConfig(t *testing.T) *tls.Config {
	t.Helper()

	pair, err := tls.X509KeyPair(testCertPEM, testKeyPEM)
	require.NoError(t, err, "load static self-signed cert")

	cfg := &tls.Config{MinVersion: tls.VersionTLS12} //nolint:gosec
	cfg.Certificates = append(cfg.Certificates, pair)
	return cfg
}
// testCertPEM / testKeyPEM are a static self-signed certificate and its
// matching private key, used only by tests that need a working TLS
// handshake. The key is an EC key (PEM type "EC PRIVATE KEY"), not RSA.
//
// NOTE(review): this looks like the sample certificate from the Go
// crypto/tls documentation; its validity window and SAN entries are not
// checked here, so clients must skip verification (the dialing test uses
// InsecureSkipVerify) — confirm before relying on it for verified
// handshakes.
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)

// testKeyPEM is the EC private key paired with testCertPEM.
var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)

View File

@@ -0,0 +1,47 @@
package auth
import (
"context"
"net/netip"
)
// PeerIdentity describes the locally-known facts about a peer reachable on
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
// Phase V2 will populate them once RemotePeerConfig carries user identity.
type PeerIdentity struct {
	// PubKey is the peer's public key as recorded in the peerstore.
	PubKey string
	// TunnelIP is the peer's address on the tunnel; it is the value a
	// TunnelLookupFunc is queried with.
	TunnelIP netip.Addr
	// FQDN is the peer's fully qualified domain name from the peerstore.
	FQDN string
	// V2 fields (zero in V1).
	UserID string
	Email  string
	Groups []string
}
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
// available peerstore data. ok=false means the IP is not in the calling
// account's roster.
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)

// tunnelLookupContextKey is the private context key type under which a
// TunnelLookupFunc is stored; being unexported, it cannot collide with
// keys defined by other packages.
type tunnelLookupContextKey struct{}
// WithTunnelLookup returns a context carrying the given per-account
// peerstore lookup. The auth middleware consults that lookup before
// calling management's ValidateTunnelPeer, so unknown IPs can be rejected
// locally and already-cached identities can skip the RPC. A nil lookup
// leaves ctx untouched and is returned as-is.
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
	if lookup != nil {
		return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
	}
	return ctx
}
// TunnelLookupFromContext extracts the peerstore lookup stored on ctx by
// WithTunnelLookup. It returns nil when no lookup is present, which is the
// case for requests that did not arrive on a per-account listener.
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
	if lookup, ok := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc); ok {
		return lookup
	}
	return nil
}

View File

@@ -36,6 +36,7 @@ type authenticator interface {
// SessionValidator validates session tokens and checks user access permissions.
// The method signatures carry gRPC call options, so a generated management
// client can presumably be used directly — confirm against the caller.
type SessionValidator interface {
	// ValidateSession asks management whether a session token is valid for
	// the requested domain; the response carries the user identity and, on
	// denial, a reason.
	ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
	// ValidateTunnelPeer asks management to validate a request that
	// originates from a tunnel peer. The middleware reads validity, a
	// session token, the user identity and the peer's group ids/names
	// from the response.
	ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
@@ -56,12 +57,21 @@ type DomainConfig struct {
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
// Private routes the domain through ValidateTunnelPeer; failure → 403.
Private bool
}
// validationResult is the middleware-internal outcome of validating a
// session token, whether it was checked by management or verified locally
// from the JWT.
type validationResult struct {
	// UserID and UserEmail identify the user the token belongs to; they
	// may be set even when Valid is false.
	UserID    string
	UserEmail string
	// Valid reports whether access was granted.
	Valid bool
	// DeniedReason is filled from the management response when Valid is
	// false.
	DeniedReason string
	// Groups holds the group ids associated with the identity.
	Groups []string
	// GroupNames carries the human-readable display names for Groups,
	// ordered identically (positional pairing). May be shorter than
	// Groups for tokens minted before names were embedded; the consumer
	// falls back to ids for missing positions.
	GroupNames []string
}
// Middleware applies per-domain authentication and IP restriction checks.
@@ -71,6 +81,7 @@ type Middleware struct {
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
tunnelCache *tunnelValidationCache
}
// NewMiddleware creates a new authentication middleware. The sessionValidator is
@@ -84,6 +95,7 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
tunnelCache: newTunnelValidationCache(),
}
}
@@ -111,6 +123,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
// Private services bypass operator schemes and gate on tunnel peer.
if config.Private {
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
@@ -129,10 +150,54 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
if mw.blockOIDCOnPlainHTTP(w, r, config) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
// requestIsPlainHTTP reports whether the request arrived without TLS,
// judged solely by r.TLS being nil (no forwarded-proto headers are
// consulted). Used to gate cookie-on-plain warnings and the OIDC
// plain-HTTP block.
func requestIsPlainHTTP(r *http.Request) bool {
	return r.TLS == nil
}
// hasOIDCScheme reports whether at least one configured scheme is of type
// auth.MethodOIDC, i.e. one that requires TLS to round-trip safely with an
// external IdP. An empty or nil list yields false.
func hasOIDCScheme(schemes []Scheme) bool {
	for i := range schemes {
		if schemes[i].Type() == auth.MethodOIDC {
			return true
		}
	}
	return false
}
// blockOIDCOnPlainHTTP rejects a request with 400 Bad Request when it
// arrived without TLS on a domain that has an OIDC scheme configured, and
// logs a warning with the request host and remote address. Most IdPs
// refuse http:// redirect URIs, so failing here gives a clearer error than
// the IdP's "invalid redirect_uri" round-trip.
//
// It returns true when the response has been written and the caller must
// stop, false when the request may proceed.
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
	if !requestIsPlainHTTP(r) || !hasOIDCScheme(config.Schemes) {
		return false
	}

	fields := log.Fields{
		"host":   r.Host,
		"remote": r.RemoteAddr,
	}
	mw.logger.WithFields(fields).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")

	http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
	return true
}
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
mw.domainsMux.RLock()
defer mw.domainsMux.RUnlock()
@@ -162,7 +227,17 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request
return false
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
var verdict restrict.Verdict
if types.IsOverlayOrigin(r.Context()) {
// Geo/CrowdSec checks don't apply over the WireGuard overlay:
// the source address is always inside the NetBird CGNAT range,
// which is never in a GeoIP database or a CrowdSec decision
// list. Enforcing them here would either no-op (best case) or
// fail-closed when the geo database is missing.
verdict = config.IPRestrictions.CheckCIDR(clientIP)
} else {
verdict = config.IPRestrictions.Check(clientIP, mw.geo)
}
if verdict == restrict.Allow {
return true
}
@@ -246,18 +321,111 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
if err != nil {
return false
}
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if err != nil {
return false
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetUserEmail(email)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return true
}
// forwardWithTunnelPeer authenticates a request by the tunnel identity of
// its source address instead of an interactive scheme. When the client IP
// looks like a tunnel address (see isTunnelSourceIP) the proxy asks
// management, through the tunnel validation cache, to resolve it to a
// peer/user and to check group access for the service. On success the
// proxy installs the minted session token as a session cookie, records
// user id, email, groups and Method=oidc on the captured data, and
// forwards to next — operators see the same access-log shape as if the
// user had completed an OIDC redirect.
//
// It returns true only when next has served the request. Every failure
// (no session validator, unresolvable or non-tunnel source IP, peer not in
// the local roster, validation error, invalid response or empty token)
// returns false without writing a response. What false means is up to the
// caller: Protect answers 403 for private domains and otherwise falls
// through to the configured scheme dispatch.
//
// Phase 3 adds a local-first short-circuit: when the request arrived on a
// per-account inbound listener the context carries a peerstore lookup
// (TunnelLookupFromContext). If that lookup does not know the IP the
// request is denied without calling management. If it does, validation
// still goes through the cache, which is intended to let repeat requests
// skip the management RPC for tunnelCacheTTL.
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
	if mw.sessionValidator == nil {
		return false
	}

	peerIP := mw.resolveClientIP(r)
	if !peerIP.IsValid() || !isTunnelSourceIP(peerIP) {
		return false
	}

	ctx := r.Context()

	// Local roster check: only applies on per-account listeners, where a
	// lookup has been attached to the context.
	if lookup := TunnelLookupFromContext(ctx); lookup != nil {
		_, known := lookup(peerIP)
		if !known {
			mw.logger.WithFields(log.Fields{
				"host":   host,
				"remote": peerIP,
			}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
			return false
		}
	}

	cacheKey := tunnelCacheKey{
		accountID: config.AccountID,
		tunnelIP:  peerIP,
		domain:    host,
	}
	resp, _, err := mw.tunnelCache.fetch(ctx, cacheKey, mw.validateTunnelPeer)
	if err != nil {
		mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
		return false
	}

	token := resp.GetSessionToken()
	if !resp.GetValid() || token == "" {
		return false
	}

	setSessionCookie(w, token, config.SessionExpiration)
	if captured := proxy.CapturedDataFromContext(ctx); captured != nil {
		captured.SetOrigin(proxy.OriginAuth)
		captured.SetUserID(resp.GetUserId())
		captured.SetUserEmail(resp.GetUserEmail())
		captured.SetUserGroups(resp.GetPeerGroupIds())
		captured.SetUserGroupNames(resp.GetPeerGroupNames())
		captured.SetAuthMethod(auth.MethodOIDC.String())
	}

	next.ServeHTTP(w, r)
	return true
}
// validateTunnelPeer adapts the SessionValidator interface to the cache's
// validateTunnelPeerFn signature by dropping the variadic gRPC call
// options. It does not nil-check mw.sessionValidator; forwardWithTunnelPeer
// does that before handing this method to the cache.
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
	return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
}
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
// allocates tunnel addresses from by default. IsPrivate() doesn't include
// it, so we check it explicitly.
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")

// isTunnelSourceIP reports whether ip falls within an address range typical
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
// (NetBird's default range). Loopback and link-local are excluded — the
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are classified by the IPv4
// address they wrap, so a tunnel peer is recognised regardless of which
// form the listener reports. The zero Addr yields false.
func isTunnelSourceIP(ip netip.Addr) bool {
	// Unmap first: netip.Prefix.Contains never matches a 4in6 address
	// against an IPv4 prefix, so without this a CGNAT peer seen as
	// ::ffff:100.64.x.y would be rejected.
	ip = ip.Unmap()
	if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
		return false
	}
	return ip.IsPrivate() || cgnatPrefix.Contains(ip)
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
@@ -286,7 +454,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
setHeaderCapturedData(r.Context(), "", "", nil, nil)
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
@@ -298,7 +466,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -306,6 +474,9 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(auth.MethodHeader.String())
}
@@ -315,7 +486,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
setHeaderCapturedData(r.Context(), "", "", nil, nil)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -327,7 +498,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
@@ -335,6 +506,9 @@ func setHeaderCapturedData(ctx context.Context, userID string) {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
cd.SetUserEmail(userEmail)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
}
// authenticateWithSchemes tries each configured auth scheme in order.
@@ -405,6 +579,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
@@ -419,6 +596,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
@@ -454,12 +634,9 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
return false
}
// AddDomain registers authentication schemes for the given domain.
// If schemes are provided, a valid session public key is required to sign/verify
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
@@ -467,6 +644,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -488,6 +666,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -518,18 +697,25 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
}).Debug("Session validation denied")
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: false,
DeniedReason: resp.DeniedReason,
}, nil
}
return &validationResult{UserID: resp.UserId, Valid: true}, nil
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: true,
Groups: resp.GetPeerGroupIds(),
GroupNames: resp.GetPeerGroupNames(),
}, nil
}
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err
}
return &validationResult{UserID: userID, Valid: true}, nil
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
}
// stripSessionTokenParam returns the request URI with the session_token query

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"errors"
"net/http"
@@ -23,6 +24,7 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -62,7 +64,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -79,7 +81,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -93,7 +95,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -108,7 +110,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -121,7 +123,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -137,8 +139,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -154,7 +156,7 @@ func TestRemoveDomain(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
mw.RemoveDomain("example.com")
@@ -178,7 +180,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -195,7 +197,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -216,7 +218,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -237,9 +239,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
@@ -262,15 +264,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
assert.Equal(t, "authenticated", rec.Body.String())
}
// TestProtect_SessionCookieGroupsPropagate signs a session token that carries
// a groups claim, presents it as the session cookie, and checks that the
// groups (and user ID) are readable from CapturedData both inside the
// downstream handler and on the original CapturedData value once the request
// has completed.
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	keys := generateTestKeyPair(t)

	pin := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	addErr := mw.AddDomain("example.com", []Scheme{pin}, keys.PublicKey, time.Hour, "", "", nil, false)
	require.NoError(t, addErr)

	wantGroups := []string{"engineering", "sre"}
	cookieValue, err := sessionkey.SignToken(keys.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, wantGroups, nil, time.Hour)
	require.NoError(t, err)

	captured := proxy.NewCapturedData("")

	// The backend inspects the CapturedData it receives through the request
	// context, which is where downstream middlewares would read it from.
	backend := func(w http.ResponseWriter, r *http.Request) {
		seen := proxy.CapturedDataFromContext(r.Context())
		require.NotNil(t, seen, "captured data must be present in request context")
		assert.Equal(t, "test-user", seen.GetUserID())
		assert.Equal(t, wantGroups, seen.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
		w.WriteHeader(http.StatusOK)
	}
	protected := mw.Protect(http.HandlerFunc(backend))

	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request = request.WithContext(proxy.WithCapturedData(request.Context(), captured))
	request.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: cookieValue})

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code, "request with valid groups-bearing cookie must succeed")
	assert.Equal(t, wantGroups, captured.GetUserGroups(), "CapturedData groups must be retained after handler completes")
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
require.NoError(t, err)
var backendCalled bool
@@ -293,10 +328,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -320,10 +355,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -345,7 +380,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
scheme := &stubScheme{
@@ -357,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -410,7 +445,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -427,7 +462,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
// First scheme (PIN) always fails, second scheme (password) succeeds.
@@ -446,7 +481,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -476,7 +511,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -500,7 +535,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
@@ -509,10 +544,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
require.Error(t, err)
// The original valid config should still be intact.
@@ -536,7 +571,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -563,7 +598,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -590,7 +625,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -678,7 +713,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -714,7 +749,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -755,7 +790,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -770,6 +805,69 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
}
// TestCheckIPRestrictions_OverlayOriginSkipsCountryRules exercises the
// inbound (WG) listener path. A request whose context was stamped with
// WithOverlayOrigin must get through a filter that combines a CIDR allowlist
// with a country allowlist, even though this middleware is built without geo
// data. The identical request without the overlay stamp must be rejected,
// which shows the bypass applies only to overlay-origin traffic.
func TestCheckIPRestrictions_OverlayOriginSkipsCountryRules(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	filter := restrict.ParseFilter(restrict.FilterConfig{
		AllowedCIDRs:     []string{"100.64.0.0/10"},
		AllowedCountries: []string{"US"},
	})
	require.NoError(t, mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", filter, false))

	protected := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// statusFor sends one GET from an address inside the allowed CIDR and
	// returns the response code. The overlay flag decides whether the
	// request context carries the overlay-origin marker.
	statusFor := func(overlay bool) int {
		req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
		req.RemoteAddr = "100.64.5.6:5000"
		req.Host = "example.com"
		if overlay {
			req = req.WithContext(types.WithOverlayOrigin(req.Context()))
		}
		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, req)
		return rec.Code
	}

	assert.Equal(t, http.StatusOK, statusFor(true),
		"overlay-origin requests must not be denied by country rules they would fail without geo data")

	// Sanity check: with the overlay flag absent the country allowlist is
	// evaluated and, lacking geo data, the request is denied.
	assert.Equal(t, http.StatusForbidden, statusFor(false),
		"WAN-origin requests must still hit the full Check path and be denied without geo data")
}
// TestCheckIPRestrictions_OverlayOriginRespectsCIDR verifies that marking a
// request as overlay-origin does not switch off CIDR filtering: a source
// address outside the allowed range is still answered with 403, so operators
// can keep scoping services to particular peer subnets.
func TestCheckIPRestrictions_OverlayOriginRespectsCIDR(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	cidrOnly := restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"100.64.0.0/16"}})
	require.NoError(t, mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", cidrOnly, false))

	okBackend := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
	protected := mw.Protect(http.HandlerFunc(okBackend))

	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request.RemoteAddr = "100.65.5.6:5000" // outside 100.64.0.0/16
	request.Host = "example.com"
	overlayCtx := types.WithOverlayOrigin(request.Context())
	request = request.WithContext(overlayCtx)

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusForbidden, recorder.Code,
		"CIDR rules must still apply on the overlay path")
}
func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
@@ -781,11 +879,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
return "", oidcURL, nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -809,11 +908,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -834,7 +934,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
@@ -852,7 +952,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
@@ -895,7 +995,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -915,7 +1015,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -938,7 +1038,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -955,7 +1055,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -1006,7 +1106,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && accepted[ha.GetHeaderValue()] {
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
@@ -1015,7 +1115,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
// Single Header scheme (as if one entry existed), but the mock checks both values.
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -1059,3 +1159,71 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
assert.False(t, backendCalled, "unknown token should be rejected")
})
}
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 registers a domain whose only
// scheme is OIDC and sends a request with no TLS state attached. The
// middleware is expected to answer 400 with an explanatory body rather than
// issuing the IdP redirect.
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	keys := generateTestKeyPair(t)

	oidc := &stubScheme{method: auth.MethodOIDC}
	oidc.authFn = func(_ *http.Request) (string, string, error) {
		return "", "https://idp.example.com/authorize", nil
	}
	addErr := mw.AddDomain("example.com", []Scheme{oidc}, keys.PublicKey, time.Hour, "", "", nil, false)
	require.NoError(t, addErr)

	protected := mw.Protect(newPassthroughHandler())

	// httptest.NewRequest leaves req.TLS nil for an http:// target, so this
	// request is seen as plain HTTP.
	plainReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, plainReq)

	assert.Equal(t, http.StatusBadRequest, recorder.Code, "OIDC over plain HTTP should be rejected")
	assert.Contains(t, recorder.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
}
// TestProtect_OIDCOverTLS_NotBlocked is the TLS counterpart of the plain-HTTP
// rejection test: with the same OIDC-only domain but a request that carries
// TLS connection state, the middleware must respond with the 302 redirect to
// the IdP.
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	keys := generateTestKeyPair(t)

	oidc := &stubScheme{method: auth.MethodOIDC}
	oidc.authFn = func(_ *http.Request) (string, string, error) {
		return "", "https://idp.example.com/authorize", nil
	}
	addErr := mw.AddDomain("example.com", []Scheme{oidc}, keys.PublicKey, time.Hour, "", "", nil, false)
	require.NoError(t, addErr)

	protected := mw.Protect(newPassthroughHandler())

	tlsReq := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
	// An explicit, non-nil connection state marks the request as TLS.
	tlsReq.TLS = &tls.ConnectionState{}
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, tlsReq)

	assert.Equal(t, http.StatusFound, recorder.Code, "OIDC over TLS should redirect to IdP")
}
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked guards against the
// plain-HTTP rejection being applied too broadly: a domain configured with a
// PIN scheme only must still answer an unauthenticated plain-HTTP request
// with 401 (the login page), not with the OIDC-specific 400.
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
	mw := NewMiddleware(log.StandardLogger(), nil, nil)
	keys := generateTestKeyPair(t)

	pinOnly := []Scheme{&stubScheme{method: auth.MethodPIN, promptID: "pin"}}
	addErr := mw.AddDomain("example.com", pinOnly, keys.PublicKey, time.Hour, "", "", nil, false)
	require.NoError(t, addErr)

	protected := mw.Protect(newPassthroughHandler())

	plainReq := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, plainReq)

	assert.Equal(t, http.StatusUnauthorized, recorder.Code, "PIN-only domain should serve the login page on plain HTTP")
}

View File

@@ -0,0 +1,171 @@
package auth
import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// tunnelCacheTTL is the maximum age of a cached positive ValidateTunnelPeer
// result. Once an entry is older than this it is treated as a miss and the
// response is fetched from management again. Five minutes trades freshness
// against management load on busy mesh networks.
const tunnelCacheTTL = 5 * time.Minute
// tunnelCachePerAccount is the upper bound on cached identities kept for a
// single account. Bounding the bucket keeps memory in check in pathological
// cases (huge peer roster, brief request bursts) while leaving ample room
// for normal use.
const tunnelCachePerAccount = 1 << 10
// tunnelCacheKey identifies a cached entry by originating account, tunnel
// IP and requested domain. All three fields take part in key equality, so
// the same peer reaching two different domains occupies two separate cache
// entries and is validated once per domain.
type tunnelCacheKey struct {
// accountID selects the per-account bucket the entry lives in.
accountID types.AccountID
// tunnelIP is the peer address sent to management as the request's TunnelIp.
tunnelIP netip.Addr
// domain is the service domain sent to management as the request's Domain.
domain string
}
// tunnelCacheEntry stores a positive validation response with the time it
// was minted. Entries past tunnelCacheTTL are treated as misses.
type tunnelCacheEntry struct {
// resp is the management response handed back to callers on a cache hit.
resp *proto.ValidateTunnelPeerResponse
// cachedAt is the cache clock reading taken when the entry was stored; its
// age is compared against the cache TTL on every read.
cachedAt time.Time
}
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
// (accountID, tunnelIP, domain). Only successful, valid responses are
// cached — denials skip the cache so policy changes apply immediately.
// Single-flight de-duplicates concurrent fetches for the same key so a
// burst of cold requests collapses into a single RPC.
type tunnelValidationCache struct {
// mu is held by get and put while they read or modify entries.
mu sync.Mutex
// entries maps each account to its own bounded bucket of cached responses.
entries map[types.AccountID]*accountBucket
// flight collapses concurrent cache misses for the same key into one call.
flight singleflight.Group
// ttl is the maximum age at which a cached entry is still served.
ttl time.Duration
// maxSize is the maximum number of entries kept per account bucket.
maxSize int
// now is the clock used for timestamps and expiry checks; it is a field so
// it can be replaced with a controllable clock.
now func() time.Time
}
// accountBucket holds the cached entries for a single account, with a
// FIFO eviction queue used when the bucket exceeds maxSize.
type accountBucket struct {
// items holds the cached entries addressed by their full cache key.
items map[tunnelCacheKey]tunnelCacheEntry
// order lists the keys in insertion order; the front element is the first
// one evicted when the bucket grows past its bound.
order []tunnelCacheKey
}
// newTunnelValidationCache returns an empty cache configured with the
// package defaults: tunnelCacheTTL as entry lifetime, tunnelCachePerAccount
// as per-account bound, and time.Now as clock.
func newTunnelValidationCache() *tunnelValidationCache {
	cache := new(tunnelValidationCache)
	cache.entries = map[types.AccountID]*accountBucket{}
	cache.ttl = tunnelCacheTTL
	cache.maxSize = tunnelCachePerAccount
	cache.now = time.Now
	return cache
}
// get returns the cached response for key, or nil when there is no entry or
// the entry is older than the cache TTL. An expired entry is evicted on the
// spot (lazy eviction on read). When that eviction leaves the account's
// bucket empty, the bucket itself is removed so the outer map does not keep
// one dead bucket for every account that was ever cached.
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
	c.mu.Lock()
	defer c.mu.Unlock()

	bucket, ok := c.entries[key.accountID]
	if !ok {
		return nil
	}
	entry, ok := bucket.items[key]
	if !ok {
		return nil
	}
	if c.now().Sub(entry.cachedAt) <= c.ttl {
		return entry.resp
	}

	// Expired: drop the entry from both the lookup map and the FIFO queue so
	// the two stay in sync.
	delete(bucket.items, key)
	bucket.order = removeKey(bucket.order, key)
	if len(bucket.items) == 0 {
		// put recreates the bucket on demand, so releasing it here is safe.
		delete(c.entries, key.accountID)
	}
	return nil
}
// put stores resp under key, stamping it with the cache clock. A key that is
// new to the account's bucket is appended to the FIFO queue; an existing key
// keeps its queue position and only has its entry replaced. While the queue
// is longer than maxSize, the entries at its front are evicted.
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
	c.mu.Lock()
	defer c.mu.Unlock()

	acct, found := c.entries[key.accountID]
	if !found {
		// First entry for this account: create its bucket lazily.
		acct = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
		c.entries[key.accountID] = acct
	}

	_, known := acct.items[key]
	if !known {
		acct.order = append(acct.order, key)
	}
	acct.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}

	for len(acct.order) > c.maxSize {
		victim := acct.order[0]
		acct.order = acct.order[1:]
		delete(acct.items, victim)
	}
}
// removeKey returns order with the first element equal to needle removed.
// The removal happens in place: later elements are shifted one slot to the
// left within the same backing array and the slice is shortened by one. When
// needle is absent, order is returned untouched. A linear scan is used
// because the queues are small.
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
	for idx := 0; idx < len(order); idx++ {
		if order[idx] != needle {
			continue
		}
		copy(order[idx:], order[idx+1:])
		return order[:len(order)-1]
	}
	return order
}
// flightKey renders a cache key as the string used to group single-flight
// calls: account, tunnel IP and domain joined by "|". The domain is included
// because different domains for the same peer and account may carry
// different group access and therefore must not share one in-flight call.
func flightKey(key tunnelCacheKey) string {
	account := string(key.accountID)
	peer := key.tunnelIP.String()
	return account + "|" + peer + "|" + key.domain
}
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
// the SessionValidator.ValidateTunnelPeer signature without exposing the
// gRPC option variadic, since callers don't need it on the cache hot path.
// fetch invokes it on a cache miss with a request built from the cache key's
// tunnel IP and domain.
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
// fetch resolves key to a validation response. A live cache entry is
// returned immediately with the boolean result set to true. Otherwise
// validate is invoked under single-flight, so concurrent misses for the same
// key share one call, and the boolean result is false. Only responses that
// are valid and carry a session token are stored; denied responses are
// passed through uncached so policy changes apply immediately. An error from
// validate is returned as-is with a nil response.
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
	if hit := c.get(key); hit != nil {
		return hit, true, nil
	}

	// load runs at most once per in-flight key. It captures the ctx of the
	// caller that started the flight, so that ctx governs the RPC for every
	// caller sharing the result.
	load := func() (any, error) {
		// Another flight may have populated the cache between the miss
		// above and this call starting; re-check before going remote.
		if hit := c.get(key); hit != nil {
			return hit, nil
		}
		request := &proto.ValidateTunnelPeerRequest{
			TunnelIp: key.tunnelIP.String(),
			Domain:   key.domain,
		}
		fresh, rpcErr := validate(ctx, request)
		if rpcErr != nil {
			return nil, rpcErr
		}
		if fresh.GetValid() && fresh.GetSessionToken() != "" {
			c.put(key, fresh)
		}
		return fresh, nil
	}

	group := flightKey(key)
	outcome, err, _ := c.flight.Do(group, load)
	if err != nil {
		return nil, false, err
	}
	resp, _ := outcome.(*proto.ValidateTunnelPeerResponse)
	return resp, false, nil
}

View File

@@ -0,0 +1,171 @@
package auth
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// newTestKey assembles a cache key from its three components. The IP is
// parsed with netip.MustParseAddr, so a malformed address panics instead of
// returning an error.
func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey {
	var key tunnelCacheKey
	key.accountID = account
	key.tunnelIP = netip.MustParseAddr(ip)
	key.domain = domain
	return key
}
func TestTunnelCache_HitSkipsRPC(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil
}
resp, fromCache, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp, "first fetch returns RPC response")
assert.False(t, fromCache, "first fetch must not be cached")
resp2, fromCache2, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp2, "second fetch returns cached response")
assert.True(t, fromCache2, "second fetch must be served from cache")
assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity")
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit")
}
// TestTunnelCache_ExpiredEntryRefetches drives the cache's injectable
// clock past tunnelCacheTTL and checks that the next fetch misses the
// cache and calls validate a second time.
func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) {
	c := newTunnelValidationCache()

	// Fake clock: the cache reads "now" through this closure.
	current := time.Now()
	c.now = func() time.Time { return current }

	var rpcCount int32
	stub := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
		atomic.AddInt32(&rpcCount, 1)
		return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
	}
	k := newTestKey("acct-1", "100.64.0.10", "svc.example")

	_, _, err := c.fetch(context.Background(), k, stub)
	require.NoError(t, err)
	assert.Equal(t, int32(1), atomic.LoadInt32(&rpcCount), "first fetch issues one RPC")

	// Jump just beyond the TTL so the stored entry is stale.
	current = current.Add(tunnelCacheTTL + time.Second)

	_, cached, err := c.fetch(context.Background(), k, stub)
	require.NoError(t, err)
	assert.False(t, cached, "expired entry must miss the cache")
	assert.Equal(t, int32(2), atomic.LoadInt32(&rpcCount), "expired entry forces a re-fetch")
}
func TestTunnelCache_DeniedResponseNotCached(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil
}
for i := 0; i < 3; i++ {
_, _, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err, "fetch must not error on denied response")
}
assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately")
}
// TestTunnelCache_ConcurrentColdHitsCoalesce starts 16 goroutines that
// fetch the same cold key while the validate callback is held on a
// channel, then releases it. Every worker must see a valid response and
// validate must have been called exactly once.
func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) {
	const workers = 16

	c := newTunnelValidationCache()
	k := newTestKey("acct-1", "100.64.0.10", "svc.example")

	release := make(chan struct{})
	var rpcCount int32
	blockingValidate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
		atomic.AddInt32(&rpcCount, 1)
		<-release // hold the call open until the test closes the channel
		return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
	}

	// Each goroutine writes only its own slot, so no extra locking is needed.
	succeeded := make([]bool, workers)
	var pending sync.WaitGroup
	for w := range succeeded {
		pending.Add(1)
		go func(slot int) {
			defer pending.Done()
			resp, _, err := c.fetch(context.Background(), k, blockingValidate)
			succeeded[slot] = err == nil && resp.GetValid()
		}(w)
	}

	// Give the workers a moment to pile up before releasing the callback.
	time.Sleep(20 * time.Millisecond)
	close(release)
	pending.Wait()

	for w, ok := range succeeded {
		assert.Truef(t, ok, "worker %d should observe a successful response", w)
	}
	assert.Equal(t, int32(1), atomic.LoadInt32(&rpcCount), "single-flight must collapse concurrent cold fetches into one RPC")
}
func TestTunnelCache_PerAccountIsolation(t *testing.T) {
cache := newTunnelValidationCache()
keyA := newTestKey("acct-a", "100.64.0.10", "svc.example")
keyB := newTestKey("acct-b", "100.64.0.10", "svc.example")
var callsA, callsB int32
validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsA, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil
}
validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsB, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil
}
respA, _, err := cache.fetch(context.Background(), keyA, validateA)
require.NoError(t, err)
respB, _, err := cache.fetch(context.Background(), keyB, validateB)
require.NoError(t, err)
assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a")
assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once")
}
// TestTunnelCache_BoundedSizeEvictsOldest caps the cache at two entries,
// inserts three keys in order and checks that only the first-inserted
// key is gone afterwards.
func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) {
	c := newTunnelValidationCache()
	c.maxSize = 2

	tokenPerIP := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
		return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil
	}

	oldest := newTestKey("acct-1", "100.64.0.10", "svc")
	middle := newTestKey("acct-1", "100.64.0.11", "svc")
	newest := newTestKey("acct-1", "100.64.0.12", "svc")

	// Insertion order matters: oldest first, newest last.
	for _, k := range []tunnelCacheKey{oldest, middle, newest} {
		_, _, err := c.fetch(context.Background(), k, tokenPerIP)
		require.NoError(t, err)
	}

	assert.Nil(t, c.get(oldest), "oldest key should be evicted past maxSize")
	assert.NotNil(t, c.get(middle), "second-newest must remain cached")
	assert.NotNil(t, c.get(newest), "newest must remain cached")
}

View File

@@ -0,0 +1,325 @@
package auth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
"sync/atomic"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/shared/management/proto"
)
// stubSessionValidator is a test double for the session validator used by
// the middleware. ValidateTunnelPeer increments tunnelCalls on every call
// and answers with respErr when set, otherwise with respFn's result when
// respFn is set, otherwise with an invalid response.
type stubSessionValidator struct {
	// respFn, when non-nil, produces the ValidateTunnelPeer response.
	respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse
	// respErr, when non-nil, is returned instead of any response.
	respErr error
	// tunnelCalls counts ValidateTunnelPeer invocations.
	tunnelCalls atomic.Int32
}
// ValidateSession always reports the session as invalid and never errors;
// all arguments are ignored.
func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
	denied := &proto.ValidateSessionResponse{Valid: false}
	return denied, nil
}
// ValidateTunnelPeer counts the call, then answers in priority order:
// respErr if set, respFn(in) if set, otherwise an invalid response.
func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
	s.tunnelCalls.Add(1)

	switch {
	case s.respErr != nil:
		return nil, s.respErr
	case s.respFn != nil:
		return s.respFn(in), nil
	default:
		return &proto.ValidateTunnelPeerResponse{Valid: false}, nil
	}
}
// newTunnelMiddleware builds a Middleware around validator and registers
// the domain "svc.example" for account "acct-1" / service "svc-1" with
// the final AddDomain flag set to false. The test fails immediately if
// registration errors.
func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware {
	t.Helper()

	mw := NewMiddleware(log.New(), validator, nil)
	err := mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false)
	require.NoError(t, err)
	return mw
}
// newTunnelRequest returns a fresh response recorder and a GET request
// for https://svc.example/ whose Host is "svc.example" and whose
// RemoteAddr is the supplied value.
func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) {
	req := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
	req.Host = "svc.example"
	req.RemoteAddr = remoteAddr
	return httptest.NewRecorder(), req
}
// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast covers the
// local short-circuit: when the tunnel lookup attached to the request
// context does not know the source IP, forwardWithTunnelPeer must report
// "not handled", must not run the next handler and must not call
// ValidateTunnelPeer at all.
func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) {
	stub := &stubSessionValidator{}
	mw := newTunnelMiddleware(t, stub)

	// A roster that recognises nobody.
	emptyRoster := TunnelLookupFunc(func(netip.Addr) (PeerIdentity, bool) {
		return PeerIdentity{}, false
	})

	rec, req := newTunnelRequest("100.64.0.99:55555")
	req = req.WithContext(WithTunnelLookup(req.Context(), emptyRoster))

	nextRan := false
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextRan = true })

	cfg, _ := mw.getDomainConfig("svc.example")
	forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, next)

	assert.False(t, forwarded, "unknown peer must fall through, not forward")
	assert.False(t, nextRan, "next handler must not run for unknown peer")
	assert.Equal(t, int32(0), stub.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss")
}
// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData checks that the
// user ID and peer group IDs returned in ValidateTunnelPeerResponse end
// up on the request's CapturedData after a successful forward, so later
// middlewares can read them without another management call.
func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
	wantGroups := []string{"engineering", "sre"}
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			resp := &proto.ValidateTunnelPeerResponse{
				Valid:        true,
				SessionToken: "tok",
				UserId:       "user-1",
				PeerGroupIds: wantGroups,
			}
			return resp
		},
	}
	mw := newTunnelMiddleware(t, stub)

	// Attach a CapturedData to the request so the middleware has
	// somewhere to record identity.
	rec, req := newTunnelRequest("100.64.0.10:55555")
	captured := proxy.NewCapturedData("")
	req = req.WithContext(proxy.WithCapturedData(req.Context(), captured))

	nextRan := false
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextRan = true })

	cfg, _ := mw.getDomainConfig("svc.example")
	forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, next)

	require.True(t, forwarded, "valid tunnel-peer response must forward")
	require.True(t, nextRan, "next handler must run")
	assert.Equal(t, "user-1", captured.GetUserID(), "user id must propagate from tunnel-peer response")
	assert.Equal(t, wantGroups, captured.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response")
}
// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs checks that a
// source IP the local lookup recognises still results in exactly one
// ValidateTunnelPeer call (the lookup only short-circuits the deny path)
// and that the request is forwarded when that call succeeds.
func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
		},
	}
	mw := newTunnelMiddleware(t, stub)

	peerIP := netip.MustParseAddr("100.64.0.10")
	// Roster containing exactly one peer.
	roster := TunnelLookupFunc(func(addr netip.Addr) (PeerIdentity, bool) {
		if addr != peerIP {
			return PeerIdentity{}, false
		}
		return PeerIdentity{PubKey: "pk", TunnelIP: addr, FQDN: "peer.netbird.cloud"}, true
	})

	rec, req := newTunnelRequest(peerIP.String() + ":55555")
	req = req.WithContext(WithTunnelLookup(req.Context(), roster))

	nextRan := false
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextRan = true })

	cfg, _ := mw.getDomainConfig("svc.example")
	forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, next)

	assert.True(t, forwarded, "known peer with valid RPC response must forward")
	assert.True(t, nextRan, "next handler must run on success")
	assert.Equal(t, int32(1), stub.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
}
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath exercises a request
// with no tunnel lookup in its context: the middleware must call
// ValidateTunnelPeer once and forward on a positive answer.
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
		},
	}
	mw := newTunnelMiddleware(t, stub)

	// Deliberately no WithTunnelLookup on this request.
	rec, req := newTunnelRequest("100.64.0.10:55555")

	nextRan := false
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { nextRan = true })

	cfg, _ := mw.getDomainConfig("svc.example")
	forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, next)

	assert.True(t, forwarded, "host-level path forwards on positive RPC result")
	assert.True(t, nextRan, "next handler runs on host-level success")
	assert.Equal(t, int32(1), stub.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
}
// TestForwardWithTunnelPeer_RPCErrorFallsThrough checks that when
// ValidateTunnelPeer returns an error the request is reported as not
// handled, and that the RPC was attempted exactly once.
func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) {
	stub := &stubSessionValidator{respErr: errors.New("management down")}
	mw := newTunnelMiddleware(t, stub)

	peerIP := netip.MustParseAddr("100.64.0.10")
	// Roster that accepts any address, so the failure comes from the RPC.
	acceptAll := TunnelLookupFunc(func(addr netip.Addr) (PeerIdentity, bool) {
		return PeerIdentity{TunnelIP: addr}, true
	})

	rec, req := newTunnelRequest(peerIP.String() + ":55555")
	req = req.WithContext(WithTunnelLookup(req.Context(), acceptAll))

	noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	cfg, _ := mw.getDomainConfig("svc.example")
	forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, noop)

	assert.False(t, forwarded, "RPC error must let the caller try other schemes")
	assert.Equal(t, int32(1), stub.tunnelCalls.Load(), "RPC was attempted exactly once")
}
// TestForwardWithTunnelPeer_CacheReusesPositiveResponse sends four
// identical requests through forwardWithTunnelPeer and checks that all
// are forwarded while ValidateTunnelPeer is called only once.
func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
		},
	}
	mw := newTunnelMiddleware(t, stub)

	for i := 0; i < 4; i++ {
		rec, req := newTunnelRequest("100.64.0.10:55555")
		noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
		cfg, _ := mw.getDomainConfig("svc.example")

		forwarded := mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, noop)
		require.True(t, forwarded, "iteration %d should forward", i)
	}

	assert.Equal(t, int32(1), stub.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management")
}
// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey registers two
// domains under different accounts and sends a request from the same
// tunnel IP to each. Both must forward and both must trigger their own
// ValidateTunnelPeer call, i.e. the cached answer for one must not be
// reused for the other.
func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"}
		},
	}

	mw := NewMiddleware(log.New(), stub, nil)
	require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
	require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))

	domains := []string{"svc-a.example", "svc-b.example"}
	for _, domain := range domains {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "https://"+domain+"/", nil)
		req.Host = domain
		req.RemoteAddr = "100.64.0.10:55555"

		noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
		cfg, _ := mw.getDomainConfig(domain)
		forwarded := mw.forwardWithTunnelPeer(rec, req, domain, cfg, noop)
		require.True(t, forwarded, "host %s should forward", domain)
	}

	assert.Equal(t, int32(2), stub.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match")
}
// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache
// sends the same request twice. The first time the local lookup does not
// know the peer, so the request must be rejected without any RPC. The
// lookup is then flipped to recognise the peer, and the second request
// must forward via exactly one ValidateTunnelPeer call, showing that the
// earlier rejection left nothing behind that would block it.
func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) {
	stub := &stubSessionValidator{
		respFn: func(*proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
			return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}
		},
	}
	mw := newTunnelMiddleware(t, stub)

	peerIP := netip.MustParseAddr("100.64.0.10")
	// inRoster toggles whether the lookup recognises peerIP.
	inRoster := false
	roster := TunnelLookupFunc(func(addr netip.Addr) (PeerIdentity, bool) {
		if !inRoster || addr != peerIP {
			return PeerIdentity{}, false
		}
		return PeerIdentity{TunnelIP: addr}, true
	})

	send := func() bool {
		rec, req := newTunnelRequest(peerIP.String() + ":55555")
		req = req.WithContext(WithTunnelLookup(req.Context(), roster))
		noop := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
		cfg, _ := mw.getDomainConfig("svc.example")
		return mw.forwardWithTunnelPeer(rec, req, "svc.example", cfg, noop)
	}

	require.False(t, send(), "first request must short-circuit")
	require.Equal(t, int32(0), stub.tunnelCalls.Load(), "short-circuit must not populate the cache")

	inRoster = true
	require.True(t, send(), "second request with peer in roster must forward via RPC")
	assert.Equal(t, int32(1), stub.tunnelCalls.Load(), "RPC runs once after peerstore catches up")
}
// TestPrivateService_FailsClosedOnTunnelPeerFailure registers a domain
// with the final AddDomain flag set to true on a middleware that has no
// validator, then checks that Protect answers 403 and never invokes the
// wrapped handler.
func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) {
	mw := NewMiddleware(log.New(), nil, nil)
	require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))

	upstreamRan := false
	upstream := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		upstreamRan = true
		w.WriteHeader(http.StatusOK)
	})
	protected := mw.Protect(upstream)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
	req.Host = "private.svc"
	req.RemoteAddr = "100.64.0.10:55555"

	protected.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusForbidden, rec.Code)
	assert.False(t, upstreamRan)
}
func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tok",
UserId: "user-1",
}
},
}
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
called := false
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, called)
}

View File

@@ -11,7 +11,6 @@ import (
"net/url"
"strings"
"time"
)
// StatusFilters contains filter options for status queries.
@@ -160,6 +159,49 @@ func (c *Client) printClients(data map[string]any) {
for _, item := range clients {
c.printClientRow(item)
}
c.printInboundListeners(clients)
}
// printInboundListeners writes a table of per-account inbound listeners
// to c.out. It scans the decoded JSON client entries, keeps those that
// carry an "inbound_listener" object, and prints nothing at all when no
// entry has one. Entries that are not JSON objects are skipped; missing
// or mistyped fields print as their zero value. Write errors are
// deliberately ignored.
func (c *Client) printInboundListeners(clients []any) {
	type listenerRow struct {
		account   string
		ip        string
		httpsPort int
		plainPort int
	}

	var listeners []listenerRow
	for _, raw := range clients {
		entry, isObject := raw.(map[string]any)
		if !isObject {
			continue
		}
		lis, hasListener := entry["inbound_listener"].(map[string]any)
		if !hasListener {
			continue
		}

		// Numeric fields arrive as float64 here, hence the conversion.
		ip, _ := lis["tunnel_ip"].(string)
		httpsRaw, _ := lis["https_port"].(float64)
		httpRaw, _ := lis["http_port"].(float64)
		account, _ := entry["account_id"].(string)

		listeners = append(listeners, listenerRow{
			account:   account,
			ip:        ip,
			httpsPort: int(httpsRaw),
			plainPort: int(httpRaw),
		})
	}

	if len(listeners) == 0 {
		return
	}

	_, _ = fmt.Fprintln(c.out)
	_, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):")
	_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP")
	_, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78))
	for i := range listeners {
		l := listeners[i]
		_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", l.account, l.ip, l.httpsPort, l.plainPort)
	}
}
func (c *Client) printClientRow(item any) {
@@ -219,7 +261,14 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta
}
func (c *Client) printClientStatus(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
_, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"])
if inbound, ok := data["inbound_listener"].(map[string]any); ok {
tunnelIP, _ := inbound["tunnel_ip"].(string)
httpsPort, _ := inbound["https_port"].(float64)
httpPort, _ := inbound["http_port"].(float64)
_, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort))
}
_, _ = fmt.Fprintln(c.out)
if status, ok := data["status"].(string); ok {
_, _ = fmt.Fprint(c.out, status)
}

View File

@@ -61,6 +61,23 @@ type clientProvider interface {
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// InboundListenerInfo is the JSON shape the debug HTTP handler uses to
// report one account's inbound listener: the tunnel IP it is bound to
// and its HTTPS and HTTP ports.
type InboundListenerInfo struct {
	// TunnelIP is the listener's tunnel address in textual form.
	TunnelIP string `json:"tunnel_ip"`
	// HTTPSPort is the port reported for HTTPS.
	HTTPSPort uint16 `json:"https_port"`
	// HTTPPort is the port reported for plain HTTP.
	HTTPPort uint16 `json:"http_port"`
}
// InboundProvider is the source of per-account inbound listener state
// for the debug handler. It is optional: the handler checks for a nil
// provider and leaves the inbound fields out of its JSON in that case.
type InboundProvider interface {
	// InboundListeners returns the listener info keyed by account ID.
	InboundListeners() map[types.AccountID]InboundListenerInfo
}
// healthChecker provides health probe state.
type healthChecker interface {
ReadinessProbe() bool
@@ -80,6 +97,7 @@ type Handler struct {
provider clientProvider
health healthChecker
certStatus certStatus
inbound InboundProvider
logger *log.Logger
startTime time.Time
templates *template.Template
@@ -108,6 +126,13 @@ func (h *Handler) SetCertStatus(cs certStatus) {
h.certStatus = cs
}
// SetInboundProvider installs the source of per-account inbound listener
// state used by the debug responses. Passing nil, or never calling this,
// keeps the inbound section out of those responses.
func (h *Handler) SetInboundProvider(p InboundProvider) {
	h.inbound = p
}
func (h *Handler) loadTemplates() error {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
@@ -323,23 +348,35 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients)
if wantJSON {
var inboundAll map[types.AccountID]InboundListenerInfo
if h.inbound != nil {
inboundAll = h.inbound.InboundListeners()
}
clientsJSON := make([]map[string]interface{}, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
row := map[string]interface{}{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
if inb, ok := inboundAll[id]; ok {
row["inbound_listener"] = inb
}
clientsJSON = append(clientsJSON, row)
}
h.writeJSON(w, map[string]interface{}{
resp := map[string]interface{}{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
})
}
if len(inboundAll) > 0 {
resp["inbound_listener_count"] = len(inboundAll)
}
h.writeJSON(w, resp)
return
}
@@ -421,10 +458,14 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
h.writeJSON(w, map[string]interface{}{
resp := map[string]interface{}{
"account_id": accountID,
"status": overview.FullDetailSummary(),
})
}
if info, ok := h.inboundInfoFor(accountID); ok {
resp["inbound_listener"] = info
}
h.writeJSON(w, resp)
return
}
@@ -437,6 +478,18 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
h.renderTemplate(w, "clientDetail", data)
}
// inboundInfoFor looks up the inbound listener info for accountID. The
// boolean is false when no inbound provider is set or when the
// provider's map has no entry for the account; the returned info is the
// zero value in both cases.
func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) {
	if h.inbound != nil {
		info, found := h.inbound.InboundListeners()[accountID]
		return info, found
	}
	return InboundListenerInfo{}, false
}
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
client, ok := h.provider.GetClient(accountID)
if !ok {

View File

@@ -52,8 +52,15 @@ type CapturedData struct {
origin ResponseOrigin
clientIP netip.Addr
userID string
authMethod string
metadata map[string]string
userEmail string
userGroups []string
// userGroupNames pairs positionally with userGroups; populated from
// the JWT's group_names claim or from ValidateSession/Tunnel
// responses. Slice may be shorter than userGroups for tokens minted
// before names were resolvable.
userGroupNames []string
authMethod string
metadata map[string]string
}
// NewCapturedData creates a CapturedData with the given request ID.
@@ -138,6 +145,81 @@ func (c *CapturedData) GetUserID() string {
return c.userID
}
// SetUserEmail stores the authenticated user's email address under the
// write lock.
func (c *CapturedData) SetUserEmail(email string) {
	c.mu.Lock()
	c.userEmail = email
	c.mu.Unlock()
}
// GetUserEmail returns the stored email address, read under the read
// lock. The result is the empty string when no email was ever set.
func (c *CapturedData) GetUserEmail() string {
	c.mu.RLock()
	email := c.userEmail
	c.mu.RUnlock()
	return email
}
// SetUserGroups stores the authenticated user's group memberships. The
// values are copied into storage owned by c, so later changes to the
// caller's slice are not observed. An empty or nil input clears the
// stored groups.
func (c *CapturedData) SetUserGroups(groups []string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(groups) > 0 {
		// Reuse the existing backing array where possible.
		c.userGroups = append(c.userGroups[:0], groups...)
		return
	}
	c.userGroups = nil
}
// GetUserGroups returns a freshly allocated copy of the stored group
// memberships, or nil when none are stored. Callers may modify the
// result freely.
func (c *CapturedData) GetUserGroups() []string {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// Appending to a nil slice yields nil for empty input and a detached
	// copy otherwise.
	return append([]string(nil), c.userGroups...)
}
// SetUserGroupNames stores the display names for the user's groups. By
// convention position i names the group at position i of the stored
// group IDs. The values are copied into storage owned by c; an empty or
// nil input clears the stored names.
func (c *CapturedData) SetUserGroupNames(names []string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(names) > 0 {
		// Reuse the existing backing array where possible.
		c.userGroupNames = append(c.userGroupNames[:0], names...)
		return
	}
	c.userGroupNames = nil
}
// GetUserGroupNames returns a freshly allocated copy of the stored group
// display names, or nil when none are stored. The result can be shorter
// than the group ID list, so consumers should fall back to the ID for
// any position that has no name.
func (c *CapturedData) GetUserGroupNames() []string {
	c.mu.RLock()
	defer c.mu.RUnlock()

	// Appending to a nil slice yields nil for empty input and a detached
	// copy otherwise.
	return append([]string(nil), c.userGroupNames...)
}
// SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()

View File

@@ -86,6 +86,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pt.RequestTimeout > 0 {
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
}
if pt.DirectUpstream {
ctx = roundtrip.WithDirectUpstream(ctx)
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
@@ -142,6 +145,8 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Header.Set(k, v)
}
stampNetBirdIdentity(r)
clientIP := extractHostIP(r.In.RemoteAddr)
if isTrustedAddr(clientIP, p.trustedProxies) {
@@ -426,3 +431,70 @@ func opErrorContains(err error, substr string) bool {
}
return false
}
// Identity headers stamped onto upstream requests. Any client-supplied
// value for either header is removed before the proxy writes its own.
const (
	// headerNetBirdUser carries the authenticated user's identity. The
	// stamping code fills it from the captured user email.
	// NOTE(review): whether a peer name is ever substituted when no user
	// is attached depends on what the auth path stores as the email;
	// confirm against the caller.
	headerNetBirdUser = "X-NetBird-User"

	// headerNetBirdGroups carries the user's groups as a comma-separated
	// list of display names, using the group ID at any position that has
	// no name. Labels that contain a comma or a byte outside the allowed
	// header-value set are left out so consumers can split on ",".
	headerNetBirdGroups = "X-NetBird-Groups"
)
// isHeaderValueSafe reports whether v is non-empty and made up only of
// horizontal tab (0x09) and the bytes 0x20 through 0x7E, i.e. space plus
// the visible ASCII characters. Any other byte, including DEL and every
// byte of a multi-byte UTF-8 sequence, makes the value unsafe. The empty
// string is rejected so the caller can decide to omit the header.
func isHeaderValueSafe(v string) bool {
	for _, b := range []byte(v) {
		printable := b >= 0x20 && b <= 0x7E
		if !printable && b != '\t' {
			return false
		}
	}
	return v != ""
}
// stampNetBirdIdentity writes the authenticated identity onto the
// outbound request as X-NetBird-User and X-NetBird-Groups.
//
// Both headers are always removed from the outbound request first, so a
// client cannot smuggle its own values through. If the inbound request
// context carries no CapturedData, nothing further happens. Otherwise:
//   - X-NetBird-User is set to the captured email when that value passes
//     isHeaderValueSafe.
//   - X-NetBird-Groups is set to a comma-joined list with one label per
//     captured group ID: the display name at the same position when one
//     exists and is non-empty, else the ID itself. Labels that contain a
//     comma or fail isHeaderValueSafe are skipped; the header is omitted
//     when no label survives.
func stampNetBirdIdentity(r *httputil.ProxyRequest) {
	out := r.Out.Header
	out.Del(headerNetBirdUser)
	out.Del(headerNetBirdGroups)

	captured := CapturedDataFromContext(r.In.Context())
	if captured == nil {
		return
	}

	if user := captured.GetUserEmail(); isHeaderValueSafe(user) {
		out.Set(headerNetBirdUser, user)
	}

	ids := captured.GetUserGroups()
	if len(ids) == 0 {
		return
	}
	names := captured.GetUserGroupNames()

	kept := make([]string, 0, len(ids))
	for idx := range ids {
		// Prefer the display name; fall back to the ID when the name
		// slot is missing or blank.
		candidate := ids[idx]
		if idx < len(names) && names[idx] != "" {
			candidate = names[idx]
		}
		if strings.Contains(candidate, ",") || !isHeaderValueSafe(candidate) {
			continue
		}
		kept = append(kept, candidate)
	}

	if len(kept) == 0 {
		return
	}
	out.Set(headerNetBirdGroups, strings.Join(kept, ","))
}

View File

@@ -1067,3 +1067,245 @@ func TestClassifyProxyError(t *testing.T) {
})
}
}
// TestStampNetBirdIdentity_NoCapturedData_StripsOnly checks that when the
// request context carries no CapturedData, client-supplied X-NetBird-User
// and X-NetBird-Groups headers are removed and nothing is stamped in
// their place.
func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	rp := &ReverseProxy{forwardedProto: "auto"}
	stamp := rp.rewriteFunc(backend, "", false, PathRewriteDefault, nil, nil)

	// Simulate a client that tries to inject both identity headers.
	req := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
	req.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
	req.In.Header.Set(headerNetBirdGroups, "admin")
	req.Out.Header = req.In.Header.Clone()

	stamp(req)

	gotUser := req.Out.Header.Get(headerNetBirdUser)
	gotGroups := req.Out.Header.Get(headerNetBirdGroups)
	assert.Empty(t, gotUser,
		"client-supplied X-NetBird-User must be stripped when no captured identity is present")
	assert.Empty(t, gotGroups,
		"client-supplied X-NetBird-Groups must be stripped when no captured identity is present")
}
// TestStampNetBirdIdentity_StampsFromCapturedData checks the full happy
// path: the captured email replaces a spoofed X-NetBird-User and the
// captured group names are joined, in order, into X-NetBird-Groups.
func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	rp := &ReverseProxy{forwardedProto: "auto"}
	stamp := rp.rewriteFunc(backend, "", false, PathRewriteDefault, nil, nil)

	req := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
	req.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
	req.Out.Header = req.In.Header.Clone()

	// Identity as the auth layer would have recorded it.
	identity := NewCapturedData("req-1")
	identity.SetUserEmail("alice@netbird.io")
	identity.SetUserGroups([]string{"grp-eng", "grp-ops"})
	identity.SetUserGroupNames([]string{"engineering", "operations"})
	req.In = req.In.WithContext(WithCapturedData(req.In.Context(), identity))

	stamp(req)

	assert.Equal(t, "alice@netbird.io", req.Out.Header.Get(headerNetBirdUser),
		"captured email must overwrite any spoofed value")
	assert.Equal(t, "engineering,operations", req.Out.Header.Get(headerNetBirdGroups),
		"group display names must be CSV-joined in positional order")
}
// TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty covers an identity
// that has groups but no email (for example a tunnel peer with no user
// attached). The groups header must still be stamped, while
// X-NetBird-User must end up absent: the spoofed inbound value is
// stripped and nothing replaces it.
func TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	rp := &ReverseProxy{forwardedProto: "auto"}
	stamp := rp.rewriteFunc(backend, "", false, PathRewriteDefault, nil, nil)

	req := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
	req.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
	req.Out.Header = req.In.Header.Clone()

	// No SetUserEmail call: the email stays empty.
	identity := NewCapturedData("req-1")
	identity.SetUserGroups([]string{"grp-machines"})
	identity.SetUserGroupNames([]string{"machines"})
	req.In = req.In.WithContext(WithCapturedData(req.In.Context(), identity))

	stamp(req)

	assert.Empty(t, req.Out.Header.Get(headerNetBirdUser),
		"X-NetBird-User must remain unset when CapturedData carries no email")
	assert.Equal(t, "machines", req.Out.Header.Get(headerNetBirdGroups),
		"groups must still be stamped for peers without a user identity")
}
// TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty covers the symmetric
// case: identity-resolved user without resolved group memberships.
func TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("carol@netbird.io")
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "carol@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
"email must be stamped even when no groups are captured")
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
"X-NetBird-Groups must remain unset when CapturedData carries no groups")
}
// TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing verifies that a
// group without a usable display name is represented by its id in the
// X-NetBird-Groups header. Two gap shapes are covered: an explicit
// empty-string name and a names slice that is shorter than the ids slice.
func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) {
	target, _ := url.Parse("http://backend.internal:8080")
	p := &ReverseProxy{forwardedProto: "auto"}
	rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)

	pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")

	cd := NewCapturedData("req-1")
	cd.SetUserEmail("bob@netbird.io")
	cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
	// "grp-b" gets an explicit empty-string display name, while "grp-c"
	// has no slot at all because the names slice is one element shorter
	// than the ids slice. Both gap shapes must fall back to the id.
	cd.SetUserGroupNames([]string{"alpha", ""})
	pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))

	rewrite(pr)

	assert.Equal(t, "alpha,grp-b,grp-c", pr.Out.Header.Get(headerNetBirdGroups),
		"empty-string and out-of-range name slots must both fall back to the group id")
}
// TestStampNetBirdIdentity_DropsLabelsWithComma covers the
// comma-separator constraint: a group display name that itself contains
// a comma is dropped from the header (rather than corrupting the list),
// and the remaining labels are stamped.
func TestStampNetBirdIdentity_DropsLabelsWithComma(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
cd.SetUserGroupNames([]string{"engineering", "EU, EMEA", "operations"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
"group label with embedded comma must be dropped, remaining labels stamped")
}
// TestStampNetBirdIdentity_RejectsControlCharsInEmail covers the
// header-injection defence: an email value containing CR/LF/control
// chars is omitted entirely (not partially stamped) so the upstream
// request stays well-formed and no header injection is possible.
func TestStampNetBirdIdentity_RejectsControlCharsInEmail(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io\r\nX-Admin: yes")
cd.SetUserGroups([]string{"grp-a"})
cd.SetUserGroupNames([]string{"engineering"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"email with CR/LF must be dropped, not partially stamped")
assert.Equal(t, "engineering", pr.Out.Header.Get(headerNetBirdGroups),
"groups remain stampable even when email is invalid")
}
// TestStampNetBirdIdentity_RejectsControlCharsInGroup covers the
// per-label defence: a group name with a control char is silently
// dropped, the rest are stamped.
func TestStampNetBirdIdentity_RejectsControlCharsInGroup(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b"})
cd.SetUserGroupNames([]string{"engineering\r\nsneaky", "operations"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Equal(t, "operations", pr.Out.Header.Get(headerNetBirdGroups),
"group label with control char must be dropped, valid ones kept")
}
// TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid covers the
// edge case where every group label is rejected: the header must not be
// set at all (rather than set to an empty string).
func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b"})
cd.SetUserGroupNames([]string{"with,comma", "with\nbreak"})
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
_, present := pr.Out.Header[http.CanonicalHeaderKey(headerNetBirdGroups)]
assert.False(t, present,
"X-NetBird-Groups must not be set when every group label is rejected")
}
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
// that carry CapturedData with no identity fields populated (e.g. the
// auth middleware ran but the request didn't authenticate). Both
// headers must be cleared and neither stamped.
func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
"X-NetBird-User must be stripped when CapturedData has no email")
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
"X-NetBird-Groups must be stripped when CapturedData has no groups")
}

View File

@@ -28,6 +28,10 @@ type PathTarget struct {
RequestTimeout time.Duration
PathRewrite PathRewriteMode
CustomHeaders map[string]string
// DirectUpstream selects the stdlib HTTP transport (host network stack)
// over the embedded NetBird WireGuard client when forwarding requests
// to this target. Default false → embedded client (existing behaviour).
DirectUpstream bool
}
// Mapping describes how a domain is routed by the HTTP reverse proxy.

View File

@@ -191,6 +191,18 @@ func (f *Filter) IsObserveOnly(v Verdict) bool {
return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve
}
// CheckCIDR evaluates addr against the CIDR allow/block rules only, skipping
// the country and CrowdSec stages. It is meant for traffic where those
// stages carry no signal, such as requests arriving from the WireGuard
// overlay: their source addresses live in the CGNAT range and have no
// meaningful geolocation or IP-reputation data.
//
// A nil Filter allows everything. The address is unmapped before the CIDR
// evaluation.
func (f *Filter) CheckCIDR(addr netip.Addr) Verdict {
	if f != nil {
		return f.checkCIDR(addr.Unmap())
	}
	return Allow
}
// Check evaluates whether addr is permitted. CIDR rules are evaluated
// first because they are O(n) prefix comparisons. Country rules run
// only when CIDR checks pass and require a geo lookup. CrowdSec checks

View File

@@ -514,6 +514,34 @@ func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) {
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
}
// TestFilter_CheckCIDR_AllowsWithoutCountryOrCrowdSec builds a filter with a
// CIDR allowlist, a country allowlist and an enforcing CrowdSec checker that
// bans the probed overlay address, then confirms CheckCIDR consults only the
// CIDR rules.
func TestFilter_CheckCIDR_AllowsWithoutCountryOrCrowdSec(t *testing.T) {
	overlayAddr := netip.MustParseAddr("100.64.5.6")
	outsideAddr := netip.MustParseAddr("198.51.100.1")

	banList := &mockCrowdSec{
		ready: true,
		decisions: map[string]*CrowdSecDecision{
			"100.64.5.6": {Type: DecisionBan},
		},
	}
	filter := ParseFilter(FilterConfig{
		AllowedCIDRs:     []string{"100.64.0.0/10"},
		AllowedCountries: []string{"US"},
		CrowdSec:         banList,
		CrowdSecMode:     CrowdSecEnforce,
	})

	// CheckCIDR skips country + CrowdSec evaluation: an address inside
	// the allowed CIDR passes even when it would be denied by CrowdSec
	// or by the country allowlist (CGNAT addresses have no geo data).
	assert.Equal(t, Allow, filter.CheckCIDR(overlayAddr),
		"CheckCIDR must not run CrowdSec lookups on overlay traffic")

	// An address outside the allowlist is still rejected by the CIDR stage.
	assert.Equal(t, DenyCIDR, filter.CheckCIDR(outsideAddr),
		"CheckCIDR must still reject addresses outside the allow list")
}
// TestFilter_CheckCIDR_NilFilter confirms CheckCIDR is safe to call on a nil
// *Filter and yields Allow.
func TestFilter_CheckCIDR_NilFilter(t *testing.T) {
	noFilter := (*Filter)(nil)
	verdict := noFilter.CheckCIDR(netip.MustParseAddr("100.64.5.6"))
	assert.Equal(t, Allow, verdict,
		"CheckCIDR on a nil filter must allow")
}
func TestFilter_HasRestrictions_CrowdSec(t *testing.T) {
cs := &mockCrowdSec{ready: true}
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})

View File

@@ -0,0 +1,112 @@
package roundtrip
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
)
// MultiTransport dispatches each request to either the embedded NetBird
// http.RoundTripper or a stdlib http.Transport based on a per-request
// context flag set by the reverse-proxy rewrite step. When the flag is
// absent (the default for every existing target), requests follow the
// embedded NetBird path — current behaviour, preserved.
//
// The stdlib branch is used when a target was configured with
// direct_upstream=true. It dials via the host's network stack, which is
// what private (`netbird proxy`) deployments and centralised proxies
// fronting host-reachable upstreams (public APIs, LAN services,
// localhost sidecars) want.
//
// An embedded roundtripper is required. To run direct-only (no WG
// branch at all), construct the MultiTransport via NewDirectOnly.
type MultiTransport struct {
// embedded handles every request that does not carry the
// direct-upstream context flag. When it is nil, RoundTrip fails such
// requests with errNoEmbeddedTransport.
embedded http.RoundTripper
// direct is the stdlib transport used for direct-upstream requests.
direct *http.Transport
// insecure is a clone of direct whose TLS config sets
// InsecureSkipVerify; RoundTrip selects it for direct-upstream requests
// that also carry the per-request skip-TLS-verify flag.
insecure *http.Transport
}
// errNoEmbeddedTransport reports that a request was routed to the embedded
// branch of a MultiTransport that has no embedded roundtripper. Returning it
// makes the misconfiguration visible to the caller; the alternative —
// quietly sending the request down the direct branch — would bypass the WG
// tunnel.
var (
	errNoEmbeddedTransport = errors.New("multitransport: embedded roundtripper not configured")
)
// NewMultiTransport builds a MultiTransport with both branches wired.
//
// embedded is the existing NetBird roundtripper used for requests that do
// not opt into the direct branch. Callers that never need the embedded
// branch should use NewDirectOnly instead. A nil embedded is not rejected
// here; RoundTrip returns errNoEmbeddedTransport for any request routed to
// the embedded branch.
//
// logger is handed to loadTransportConfig; nil selects the logrus standard
// logger. The two direct transports (verifying and skip-verify) take their
// pool, timeout and buffer settings from loadTransportConfig, i.e. the same
// NB_PROXY_* tuning env vars as the embedded transport, and dial through
// dialWithTimeout so a per-request dial timeout set via
// types.WithDialTimeout is respected.
func NewMultiTransport(embedded http.RoundTripper, logger *log.Logger) *MultiTransport {
	if logger == nil {
		logger = log.StandardLogger()
	}
	tuning := loadTransportConfig(logger)

	// Host-stack dialer shared by both direct transports.
	hostDialer := new(net.Dialer)
	hostDialer.Timeout = 30 * time.Second
	hostDialer.KeepAlive = 30 * time.Second

	verifying := &http.Transport{
		DialContext:           dialWithTimeout(hostDialer.DialContext),
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          tuning.maxIdleConns,
		MaxIdleConnsPerHost:   tuning.maxIdleConnsPerHost,
		MaxConnsPerHost:       tuning.maxConnsPerHost,
		IdleConnTimeout:       tuning.idleConnTimeout,
		TLSHandshakeTimeout:   tuning.tlsHandshakeTimeout,
		ExpectContinueTimeout: tuning.expectContinueTimeout,
		ResponseHeaderTimeout: tuning.responseHeaderTimeout,
		WriteBufferSize:       tuning.writeBufferSize,
		ReadBufferSize:        tuning.readBufferSize,
		DisableCompression:    tuning.disableCompression,
	}

	// Same settings, but with certificate verification disabled for
	// targets that opted into skip-TLS-verify.
	skipVerify := verifying.Clone()
	skipVerify.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in

	mt := new(MultiTransport)
	mt.embedded = embedded
	mt.direct = verifying
	mt.insecure = skipVerify
	return mt
}
// NewDirectOnly returns a MultiTransport with no usable embedded branch:
// the embedded slot is filled with noEmbeddedRoundTripper, whose RoundTrip
// always returns errNoEmbeddedTransport. Requests that carry the
// direct-upstream flag are served by the direct branch; requests without
// the flag fail with errNoEmbeddedTransport instead of being routed over
// the host stack, so the WG path can never be bypassed silently — wiring
// code that needs WG must use NewMultiTransport.
func NewDirectOnly(logger *log.Logger) *MultiTransport {
return NewMultiTransport(noEmbeddedRoundTripper{}, logger)
}
// noEmbeddedRoundTripper is the sentinel embedded transport for
// direct-only MultiTransports. MultiTransport.RoundTrip hands it every
// request that lacks the direct-upstream flag, and it fails loudly with
// errNoEmbeddedTransport instead of falling back to direct.
type noEmbeddedRoundTripper struct{}
// RoundTrip rejects every request with errNoEmbeddedTransport.
func (noEmbeddedRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, errNoEmbeddedTransport
}
// RoundTrip picks a transport from flags on the request context.
//
//   - No direct-upstream flag: the embedded NetBird roundtripper handles the
//     request; if none is configured the call fails with
//     errNoEmbeddedTransport.
//   - Direct-upstream flag plus the per-request skip-TLS-verify flag: the
//     insecure stdlib transport handles it.
//   - Direct-upstream flag only: the verifying stdlib transport handles it.
func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	switch {
	case !DirectUpstreamFromContext(ctx):
		if m.embedded == nil {
			return nil, errNoEmbeddedTransport
		}
		return m.embedded.RoundTrip(req)
	case skipTLSVerifyFromContext(ctx):
		return m.insecure.RoundTrip(req)
	default:
		return m.direct.RoundTrip(req)
	}
}

View File

@@ -0,0 +1,134 @@
package roundtrip
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubRoundTripper is a network-free http.RoundTripper for dispatch tests.
// It remembers that it was invoked and answers every request with a 200
// response whose body is the configured string.
type stubRoundTripper struct {
	called bool
	body   string
}

// RoundTrip marks the stub as called and returns the canned 200 response.
// The request itself is ignored.
func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	s.called = true

	canned := new(http.Response)
	canned.StatusCode = http.StatusOK
	canned.Header = make(http.Header)
	canned.Body = io.NopCloser(strings.NewReader(s.body))
	return canned, nil
}
// TestMultiTransport_DispatchesByContextFlag checks both sides of the
// dispatch decision: an unflagged request reaches the embedded stub, and a
// request marked with WithDirectUpstream reaches a local test server through
// the stdlib transport without touching the stub.
func TestMultiTransport_DispatchesByContextFlag(t *testing.T) {
	wgStub := &stubRoundTripper{body: "embedded"}
	transport := NewMultiTransport(wgStub, nil)

	t.Run("default routes to embedded", func(t *testing.T) {
		wgStub.called = false

		plainReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
		resp, err := transport.RoundTrip(plainReq)
		require.NoError(t, err, "embedded path must not error on stubbed transport")
		require.NotNil(t, resp)
		_ = resp.Body.Close()

		assert.True(t, wgStub.called, "request without WithDirectUpstream must hit the embedded transport")
	})

	t.Run("WithDirectUpstream skips embedded", func(t *testing.T) {
		wgStub.called = false

		// A real local server proves the request left via the stdlib transport.
		backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			_, _ = io.WriteString(w, "direct")
		}))
		defer backend.Close()

		directCtx := WithDirectUpstream(context.Background())
		directReq, err := http.NewRequestWithContext(directCtx, http.MethodGet, backend.URL, nil)
		require.NoError(t, err)

		resp, err := transport.RoundTrip(directReq)
		require.NoError(t, err, "direct path must dial via stdlib transport")

		payload, err := io.ReadAll(resp.Body)
		_ = resp.Body.Close()
		require.NoError(t, err)

		assert.Equal(t, "direct", string(payload), "stdlib transport must reach the test server")
		assert.False(t, wgStub.called, "WithDirectUpstream must bypass the embedded transport")
	})
}
// TestMultiTransport_AppliesEnvOverridesToDirect sets a handful of the
// NB_PROXY_* tuning env vars read by loadTransportConfig and checks that the
// resulting values land on both direct transports (verifying and
// skip-verify), not only on the embedded roundtripper.
func TestMultiTransport_AppliesEnvOverridesToDirect(t *testing.T) {
	overrides := []struct{ name, value string }{
		{EnvMaxIdleConns, "42"},
		{EnvIdleConnTimeout, "11s"},
		{EnvTLSHandshakeTimeout, "7s"},
	}
	for _, o := range overrides {
		t.Setenv(o.name, o.value)
	}

	transport := NewMultiTransport(&stubRoundTripper{body: "embedded"}, nil)

	verifying := transport.direct
	assert.Equal(t, 42, verifying.MaxIdleConns,
		"NB_PROXY_MAX_IDLE_CONNS must propagate to the direct transport")
	assert.Equal(t, 11*time.Second, verifying.IdleConnTimeout,
		"NB_PROXY_IDLE_CONN_TIMEOUT must propagate to the direct transport")
	assert.Equal(t, 7*time.Second, verifying.TLSHandshakeTimeout,
		"NB_PROXY_TLS_HANDSHAKE_TIMEOUT must propagate to the direct transport")

	assert.Equal(t, 42, transport.insecure.MaxIdleConns,
		"env tuning must also apply to the insecure-skip-verify direct transport")
}
// TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested builds a
// MultiTransport with a nil embedded roundtripper and sends a request that
// does not opt into the direct branch. The call must fail with the
// errNoEmbeddedTransport sentinel rather than being routed over the host
// stack, which would bypass WireGuard.
func TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested(t *testing.T) {
	transport := NewMultiTransport(nil, nil)
	plainReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)

	resp, err := transport.RoundTrip(plainReq)
	// Close defensively in case a response is returned despite expectations.
	if resp != nil {
		_ = resp.Body.Close()
	}

	require.Error(t, err, "nil embedded must surface as an explicit error, not a silent direct dispatch")
	assert.Nil(t, resp)
	assert.ErrorIs(t, err, errNoEmbeddedTransport,
		"the error must be the sentinel so callers can distinguish misconfiguration from network failures")
}
// TestMultiTransport_DirectOnlyServesDirectBranch exercises NewDirectOnly
// from both sides: a request flagged with WithDirectUpstream is served by a
// local test server, while an unflagged request is refused with the
// errNoEmbeddedTransport sentinel.
func TestMultiTransport_DirectOnlyServesDirectBranch(t *testing.T) {
	transport := NewDirectOnly(nil)

	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, "ok")
	}))
	defer backend.Close()

	// Flagged request: must go out through the direct branch.
	directCtx := WithDirectUpstream(context.Background())
	directReq, err := http.NewRequestWithContext(directCtx, http.MethodGet, backend.URL, nil)
	require.NoError(t, err)

	resp, err := transport.RoundTrip(directReq)
	require.NoError(t, err, "direct-only must serve requests that opt into the direct branch")
	_ = resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// Unflagged request: there is no embedded branch to serve it.
	plainReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
	resp, err = transport.RoundTrip(plainReq)
	if resp != nil {
		_ = resp.Body.Close()
	}
	require.Error(t, err, "direct-only must refuse requests that didn't opt into the direct branch")
	assert.Nil(t, resp)
	assert.ErrorIs(t, err, errNoEmbeddedTransport)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"sync"
"time"
@@ -76,11 +77,11 @@ type clientEntry struct {
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// ready is closed once the client has been fully initialized.
// Callers that find a pending entry wait on this channel before
// accessing the client. A nil initErr means success.
ready chan struct{}
initErr error
// inbound is opaque per-account state owned by the NetBird parent's
// ReadyHandler. The roundtrip package never inspects this value; it
// only stores it so RemovePeer / StopAll can hand it back to the
// matching StopHandler. Nil when no inbound integration is active.
inbound any
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
@@ -88,6 +89,19 @@ type clientEntry struct {
maxInflight int
}
// IdentityForIP looks up a tunnel IP in this account's embedded client and
// returns the peer's public key and FQDN as that client reports them.
//
// ok is false when the receiver is nil, when it has no embedded client, when
// ip is not a valid address, or when the client does not know the IP —
// callers can treat that as a fast deny without a management round-trip.
// Only what the embedded peerstore exposes is returned; user identity
// (UserID / Email / Groups) still flows through ValidateTunnelPeer.
func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
	switch {
	case e == nil:
		return "", "", false
	case e.client == nil:
		return "", "", false
	case !ip.IsValid():
		return "", "", false
	}
	return e.client.IdentityForIP(ip)
}
// acquireInflight attempts to acquire an in-flight slot for the given backend.
// It returns a release function that must always be called, and true on success.
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
@@ -117,6 +131,12 @@ type ClientConfig struct {
MgmtAddr string
WGPort uint16
PreSharedKey string
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
// standalone proxy where the embedded client never accepts inbound;
// set to false on the private/embedded proxy so the engine creates
// the ACL manager and applies management's per-policy firewall rules
// (which is what gates per-account inbound listeners on the netstack).
BlockInbound bool
}
type statusNotifier interface {
@@ -142,6 +162,14 @@ type NetBird struct {
clients map[types.AccountID]*clientEntry
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
// Ready. The opaque return value is stored on clientEntry and handed
// back to stopHandler when the entry is torn down. Nil disables the
// hook entirely (default for the standalone proxy).
readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any
// stopHandler runs when an account's last service is removed (or the
// transport is shutting down). Receives whatever readyHandler returned.
stopHandler func(accountID types.AccountID, state any)
// OnAddPeer, when set, is called after AddPeer completes for a new account
// (i.e. when a new client was actually created, not when an existing one
@@ -167,9 +195,6 @@ type skipTLSVerifyContextKey struct{}
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple services can share the same client.
//
// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux
// so that concurrent AddPeer calls for different accounts execute in parallel.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
@@ -177,23 +202,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
entry, exists := n.clients[accountID]
if exists {
ready := entry.ready
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
// If the entry is still being initialized by another goroutine, wait.
if ready != nil {
select {
case <-ready:
case <-ctx.Done():
return ctx.Err()
}
if entry.initErr != nil {
return fmt.Errorf("peer initialization failed: %w", entry.initErr)
}
}
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
@@ -210,43 +222,19 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
return nil
}
// Insert a placeholder so other goroutines calling AddPeer for the same
// account will wait on the ready channel instead of starting a second
// client creation.
entry = &clientEntry{
services: map[ServiceKey]serviceInfo{key: si},
ready: make(chan struct{}),
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
createStart := time.Now()
created, err := n.createClientEntry(ctx, accountID, key, authToken, si)
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if n.OnAddPeer != nil {
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
entry.initErr = err
close(entry.ready)
n.clientsMux.Lock()
delete(n.clients, accountID)
n.clientsMux.Unlock()
return err
}
// Transfer any services that were registered by concurrent AddPeer calls
// while we were creating the client.
n.clientsMux.Lock()
for k, v := range entry.services {
created.services[k] = v
}
created.ready = nil
n.clients[accountID] = created
n.clients[accountID] = entry
n.clientsMux.Unlock()
close(entry.ready)
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
@@ -254,13 +242,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, created.client)
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client.
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -318,9 +306,15 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
BlockInbound: n.clientCfg.BlockInbound,
// The embedded proxy peer must never be a stepping stone into
// the proxy host's LAN: it only exists to reach NetBird mesh
// targets or, when direct_upstream is set, the host network
// stack via the MultiTransport's direct branch (which bypasses
// the engine routing entirely).
BlockLANAccess: true,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
@@ -385,8 +379,25 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
readyHandler := n.readyHandler
n.clientsMux.Unlock()
if readyHandler != nil {
state := readyHandler(ctx, accountID, client)
n.clientsMux.Lock()
if e, ok := n.clients[accountID]; ok {
e.inbound = state
} else if state != nil && n.stopHandler != nil {
// Account was removed while readyHandler ran; tear down the
// resources it just brought up.
stop := n.stopHandler
n.clientsMux.Unlock()
stop(accountID, state)
n.clientsMux.Lock()
}
n.clientsMux.Unlock()
}
if n.statusNotifier == nil {
return
}
@@ -432,11 +443,15 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -450,6 +465,9 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
@@ -536,8 +554,12 @@ func (n *NetBird) StopAll(ctx context.Context) error {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
stopHandler := n.stopHandler
var merr *multierror.Error
for accountID, entry := range n.clients {
if entry.inbound != nil && stopHandler != nil {
stopHandler(accountID, entry.inbound)
}
entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
@@ -590,6 +612,19 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
return entry.client, true
}
// IdentityForIP resolves a tunnel IP against the embedded client registered
// for accountID. The client map is read under the read lock; the actual
// lookup is delegated to clientEntry.IdentityForIP after the lock is
// released. ok is false when the account has no entry or when the entry's
// lookup reports no match.
func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) {
	n.clientsMux.RLock()
	acct, found := n.clients[accountID]
	n.clientsMux.RUnlock()

	if found {
		return acct.IdentityForIP(ip)
	}
	return "", "", false
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
@@ -645,6 +680,18 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// SetClientLifecycle installs the per-account lifecycle hooks.
//
// ready runs once an embedded client becomes ready; whatever it returns is
// kept on the account's entry as opaque state. stop runs when that entry is
// torn down and receives the stored state back. Both are stored under
// clientsMux. Must be called before AddPeer. Passing a nil pair leaves the
// outbound-only behaviour intact.
func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) {
	n.clientsMux.Lock()
	n.readyHandler, n.stopHandler = ready, stop
	n.clientsMux.Unlock()
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
@@ -687,3 +734,22 @@ func skipTLSVerifyFromContext(ctx context.Context) bool {
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
return v
}
// directUpstreamContextKey is the context key under which the
// direct-upstream marker is stored. A request carrying the marker should
// bypass the embedded NetBird WireGuard client and dial via the host's
// network stack instead. The marker is set by the reverse-proxy rewrite step
// when the matched target carries PathTarget.DirectUpstream and is consumed
// by MultiTransport.
type directUpstreamContextKey struct{}

// WithDirectUpstream returns a child of ctx carrying the direct-upstream
// marker, so MultiTransport routes the request through its stdlib transport
// instead of the embedded NetBird roundtripper.
func WithDirectUpstream(ctx context.Context) context.Context {
	var key directUpstreamContextKey
	return context.WithValue(ctx, key, true)
}

// DirectUpstreamFromContext reports whether ctx carries the direct-upstream
// marker. A context without the marker yields false.
func DirectUpstreamFromContext(ctx context.Context) bool {
	marked, ok := ctx.Value(directUpstreamContextKey{}).(bool)
	return ok && marked
}

View File

@@ -3,6 +3,7 @@ package roundtrip
import (
"context"
"net/http"
"net/netip"
"sync"
"testing"
@@ -305,6 +306,36 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
assert.True(t, calls[0].connected)
}
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse verifies the public
// lookup reports ok=false for an account that has no registered client. The
// auth middleware relies on that result as a fast deny.
func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) {
	transport := mockNetBird()
	tunnelIP := netip.MustParseAddr("100.64.0.10")

	_, _, found := transport.IdentityForIP("acct-missing", tunnelIP)

	assert.False(t, found, "unknown account must yield ok=false")
}
// TestClientEntry_IdentityForIP_NilClientGuard checks that the lookup is
// safe on partially-initialized state — a nil *clientEntry and an entry
// whose embedded client is nil — as can occur briefly during AddPeer setup
// or in test fixtures. Both must report ok=false.
func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) {
	tunnelIP := netip.MustParseAddr("100.64.0.10")

	var missing *clientEntry
	_, _, found := missing.IdentityForIP(tunnelIP)
	assert.False(t, found, "nil clientEntry must yield ok=false")

	clientless := &clientEntry{}
	_, _, found = clientless.IdentityForIP(tunnelIP)
	assert.False(t, found, "clientEntry with nil embed.Client must yield ok=false")
}
// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse passes the zero
// netip.Addr to the lookup and expects ok=false, so callers don't have to
// repeat the validity check themselves.
func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
	var zeroAddr netip.Addr
	entry := &clientEntry{}

	_, _, found := entry.IdentityForIP(zeroAddr)

	assert.False(t, found, "invalid IP must yield ok=false")
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{

View File

@@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(hello)
conn := &readerConn{Reader: r}
sni, wrapped, err := PeekClientHello(conn)
sni, wrapped, _, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}
@@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(httpReq)
conn := &readerConn{Reader: r}
_, wrapped, err := PeekClientHello(conn)
_, wrapped, _, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}

View File

@@ -100,28 +100,50 @@ type Router struct {
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
httpCh chan net.Conn
httpListener *chanListener
mu sync.RWMutex
routes map[SNIHost][]Route
fallback *Route
draining bool
dialResolve DialResolver
activeConns sync.WaitGroup
activeRelays sync.WaitGroup
relaySem chan struct{}
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server.
// Set only when NewRouter is called with WithPlainHTTP option (used by
// per-account inbound listeners that accept both :80 and :443 traffic).
// Nil for the host SNI router and for port routers.
httpPlainCh chan net.Conn
httpPlainListener *chanListener
mu sync.RWMutex
routes map[SNIHost][]Route
fallback *Route
draining bool
dialResolve DialResolver
activeConns sync.WaitGroup
activeRelays sync.WaitGroup
relaySem chan struct{}
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context
svcCancels map[types.ServiceID]context.CancelFunc
}
// RouterOption customises Router construction. An option is a function
// that receives the Router being built and may mutate its fields; options
// are passed as trailing arguments to NewRouter.
type RouterOption func(*Router)
// WithPlainHTTP returns a RouterOption that equips the router with a
// second, plain-HTTP channel and a listener (reporting addr as its
// address) backed by that channel. With the option applied, connections
// whose first byte is not a TLS handshake are handed to the listener
// returned by HTTPListenerPlain instead of the TLS channel. It is meant
// for per-account inbound listeners that serve :80 and :443 traffic
// through one router.
func WithPlainHTTP(addr net.Addr) RouterOption {
	return func(r *Router) {
		r.httpPlainCh = make(chan net.Conn, httpChannelBuffer)
		r.httpPlainListener = newChanListener(r.httpPlainCh, addr)
	}
}
// NewRouter creates a new SNI-based connection router.
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router {
httpCh := make(chan net.Conn, httpChannelBuffer)
return &Router{
r := &Router{
logger: logger,
httpCh: httpCh,
httpListener: newChanListener(httpCh, addr),
@@ -131,6 +153,10 @@ func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Rou
svcCtxs: make(map[types.ServiceID]context.Context),
svcCancels: make(map[types.ServiceID]context.CancelFunc),
}
for _, opt := range opts {
opt(r)
}
return r
}
// NewPortRouter creates a Router for a dedicated port without an HTTP
@@ -153,6 +179,16 @@ func (r *Router) HTTPListener() net.Listener {
return r.httpListener
}
// HTTPListenerPlain exposes the listener that yields non-TLS connections,
// intended to back a parallel plain http.Server. It returns nil when the
// router was built without WithPlainHTTP.
func (r *Router) HTTPListenerPlain() net.Listener {
	// Return an untyped nil explicitly: wrapping a nil *chanListener in the
	// interface would produce a non-nil net.Listener.
	if plain := r.httpPlainListener; plain != nil {
		return plain
	}
	return nil
}
// AddRoute registers an SNI route. Multiple routes for the same host are
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
@@ -254,6 +290,9 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
if r.httpListener != nil {
r.httpListener.Close()
}
if r.httpPlainListener != nil {
r.httpPlainListener.Close()
}
case <-done:
}
}()
@@ -270,6 +309,7 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
r.logger.Debugf("SNI router accept: %v", err)
continue
}
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {
defer r.activeConns.Done()
@@ -278,13 +318,24 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}
// HandleConn runs conn through the router's peek-and-dispatch logic on
// behalf of an accept loop the router does not own. It is intended for a
// secondary listener served by the same router (for example, a
// per-account inbound :80 socket next to its :443 socket). The connection
// is counted in the router's active-connection wait group for the
// duration of the call.
func (r *Router) HandleConn(ctx context.Context, conn net.Conn) {
	inflight := &r.activeConns
	inflight.Add(1)
	defer inflight.Done()

	r.handleConn(ctx, conn)
}
// handleConn peeks at the TLS ClientHello and routes the connection.
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
// fallback port), skip the TLS peek entirely to avoid read errors on
// non-TLS connections and reduce latency.
if r.isFallbackOnly() {
r.handleUnmatched(ctx, conn)
r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr())
r.handleUnmatched(ctx, conn, false)
return
}
@@ -294,11 +345,11 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return
}
sni, wrapped, err := PeekClientHello(conn)
sni, wrapped, isTLS, err := PeekClientHello(conn)
if err != nil {
r.logger.Debugf("SNI peek: %v", err)
r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err)
if wrapped != nil {
r.handleUnmatched(ctx, wrapped)
r.handleUnmatched(ctx, wrapped, isTLS)
} else {
_ = conn.Close()
}
@@ -313,13 +364,20 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host)
r.logger.WithFields(log.Fields{
"remote": wrapped.RemoteAddr().String(),
"sni": string(host),
"match": ok,
"tls": isTLS,
}).Debug("SNI route lookup")
if !ok {
r.handleUnmatched(ctx, wrapped)
r.handleUnmatched(ctx, wrapped, isTLS)
return
}
if route.Type == RouteHTTP {
r.sendToHTTP(wrapped)
r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID)
r.sendToHTTP(wrapped, isTLS)
return
}
@@ -344,15 +402,17 @@ func (r *Router) isFallbackOnly() bool {
}
// handleUnmatched routes a connection that didn't match any SNI route.
// This includes ECH/ESNI connections where the cleartext SNI is empty.
// This includes ECH/ESNI connections where the cleartext SNI is empty,
// and plain (non-TLS) HTTP connections when isTLS is false.
// It tries the fallback relay first, then the HTTP channel, and closes
// the connection if neither is available.
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) {
r.mu.RLock()
fb := r.fallback
r.mu.RUnlock()
if fb != nil {
r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target)
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
@@ -364,7 +424,8 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
}
return
}
r.sendToHTTP(conn)
r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr())
r.sendToHTTP(conn, isTLS)
}
// lookupRoute returns the highest-priority route for the given SNI host.
@@ -386,10 +447,20 @@ func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
}
// sendToHTTP feeds the connection to the HTTP handler via the channel.
// If no HTTP channel is configured (port router), the router is
// draining, or the channel is full, the connection is closed.
func (r *Router) sendToHTTP(conn net.Conn) {
if r.httpCh == nil {
// When isTLS is false and a plain channel is configured the connection
// is forwarded to the plain channel; otherwise it lands on the TLS
// channel. If no usable channel exists, the router is draining, or the
// channel is full, the connection is closed.
func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) {
ch := r.httpCh
chanName := "HTTP"
if !isTLS && r.httpPlainCh != nil {
ch = r.httpPlainCh
chanName = "HTTP-plain"
}
if ch == nil {
r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr())
_ = conn.Close()
return
}
@@ -399,14 +470,15 @@ func (r *Router) sendToHTTP(conn net.Conn) {
r.mu.RUnlock()
if draining {
r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr())
_ = conn.Close()
return
}
select {
case r.httpCh <- conn:
case ch <- conn:
default:
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr())
_ = conn.Close()
}
}

View File

@@ -1739,3 +1739,97 @@ func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}
// TestRouter_PlainHTTP_RoutesToPlainChannel verifies that a plain (non-TLS)
// connection lands on the plain HTTP channel when the router was built
// with WithPlainHTTP, leaving the TLS channel untouched.
func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
	logger := log.StandardLogger()
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
	router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
	router.AddRoute("example.com", Route{Type: RouteHTTP})

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "test listener bind must succeed")
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go func() {
		_ = router.Serve(ctx, ln)
	}()

	// Plain HTTP request (no TLS handshake byte). The dial and write happen
	// on the test goroutine so that a failure is reported directly instead
	// of surfacing as a timeout below, and so the client socket is closed
	// when the test ends. The listener is already bound, so the dial
	// completes without waiting for the router's accept loop.
	client, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
	require.NoError(t, err, "dial to router listener must succeed")
	defer client.Close()

	_, err = client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
	require.NoError(t, err, "plain HTTP request write must succeed")

	plainListener := router.HTTPListenerPlain()
	require.NotNil(t, plainListener, "plain listener must be exposed when WithPlainHTTP is set")

	// Accept in the background so the select below can also watch the TLS
	// channel and the timeout.
	acceptDone := make(chan net.Conn, 1)
	go func() {
		conn, err := plainListener.Accept()
		if err == nil {
			acceptDone <- conn
		}
	}()

	select {
	case conn := <-acceptDone:
		require.NotNil(t, conn)
		_ = conn.Close()
	case <-router.HTTPListener().(*chanListener).ch:
		t.Fatal("plain HTTP request leaked into TLS channel")
	case <-time.After(3 * time.Second):
		t.Fatal("plain HTTP connection never reached plain channel")
	}
}
// TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled verifies that the
// presence of a plain channel does not divert TLS traffic — TLS still
// goes to the TLS channel as before.
func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
router.AddRoute("example.com", Route{Type: RouteHTTP})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "test listener bind must succeed")
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Send a TLS ClientHello.
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = tlsConn.Handshake()
_ = tlsConn.Close()
}()
select {
case conn := <-router.httpCh:
require.NotNil(t, conn, "TLS conn should land on the TLS channel")
_ = conn.Close()
case <-time.After(5 * time.Second):
t.Fatal("TLS conn never reached the TLS channel")
}
}

View File

@@ -30,26 +30,30 @@ const (
// bytes transparently. If the data is not a valid TLS ClientHello or
// contains no SNI extension, sni is empty and err is nil.
//
// isTLS reports whether the first byte indicated a TLS handshake record.
// Callers can use this to distinguish plain (non-TLS) traffic from a TLS
// stream that simply lacked an SNI extension or used ECH.
//
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
// real server name is encrypted inside the encrypted_client_hello
// extension. This parser only reads the cleartext server_name extension
// (type 0x0000), so ECH connections return sni="" and are routed through
// the fallback path (or HTTP channel), which is the correct behavior
// for a transparent proxy that does not terminate TLS.
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, isTLS bool, err error) {
// Read the 5-byte TLS record header into a small stack-friendly buffer.
var header [tlsRecordHeaderLen]byte
if _, err := io.ReadFull(conn, header[:]); err != nil {
return "", nil, fmt.Errorf("read TLS record header: %w", err)
return "", nil, false, fmt.Errorf("read TLS record header: %w", err)
}
if header[0] != contentTypeHandshake {
return "", newPeekedConn(conn, header[:]), nil
return "", newPeekedConn(conn, header[:]), false, nil
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
if recordLen == 0 || recordLen > maxClientHelloLen {
return "", newPeekedConn(conn, header[:]), nil
return "", newPeekedConn(conn, header[:]), true, nil
}
// Single allocation for header + payload. The peekedConn takes
@@ -59,11 +63,11 @@ func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
if err != nil {
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), true, fmt.Errorf("read TLS handshake payload: %w", err)
}
sni = extractSNI(buf[tlsRecordHeaderLen:])
return sni, newPeekedConn(conn, buf), nil
return sni, newPeekedConn(conn, buf), true, nil
}
// extractSNI parses a TLS handshake payload to find the SNI extension.

View File

@@ -29,10 +29,11 @@ func TestPeekClientHello_ValidSNI(t *testing.T) {
_ = tlsConn.Handshake()
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
assert.NotNil(t, wrapped, "wrapped connection should not be nil")
assert.True(t, isTLS, "TLS ClientHello should be flagged as TLS")
// Verify the wrapped connection replays the peeked bytes.
// Read the first 5 bytes (TLS record header) to confirm replay.
@@ -83,10 +84,11 @@ func TestPeekClientHello_MultipleSNIs(t *testing.T) {
_ = tlsConn.Handshake()
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Equal(t, tt.expectedSNI, sni)
assert.NotNil(t, wrapped)
assert.True(t, isTLS, "TLS handshake should be flagged as TLS")
})
}
}
@@ -102,10 +104,11 @@ func TestPeekClientHello_NonTLSData(t *testing.T) {
_, _ = clientConn.Write(httpData)
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Empty(t, sni, "should return empty SNI for non-TLS data")
assert.NotNil(t, wrapped)
assert.False(t, isTLS, "plain HTTP data should not be flagged as TLS")
// Verify the wrapped connection still provides the original data.
buf := make([]byte, len(httpData))
@@ -124,7 +127,7 @@ func TestPeekClientHello_TruncatedHeader(t *testing.T) {
clientConn.Close()
}()
_, _, err := PeekClientHello(serverConn)
_, _, _, err := PeekClientHello(serverConn)
assert.Error(t, err, "should error on truncated header")
}
@@ -140,7 +143,7 @@ func TestPeekClientHello_TruncatedPayload(t *testing.T) {
clientConn.Close()
}()
_, _, err := PeekClientHello(serverConn)
_, _, _, err := PeekClientHello(serverConn)
assert.Error(t, err, "should error on truncated payload")
}
@@ -154,10 +157,11 @@ func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Empty(t, sni)
assert.NotNil(t, wrapped)
assert.True(t, isTLS, "zero-length record should still be a TLS handshake byte")
}
func TestExtractSNI_InvalidPayload(t *testing.T) {

View File

@@ -54,3 +54,23 @@ func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) {
d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration)
return d, ok && d > 0
}
// overlayOriginKey is the context key under which per-account inbound
// listeners record that a request arrived over the WireGuard overlay
// rather than through the public-facing host listener.
type overlayOriginKey struct{}

// WithOverlayOrigin returns a child of ctx flagged as originating from the
// embedded NetBird overlay (the tunnel-side inbound listener).
func WithOverlayOrigin(ctx context.Context) context.Context {
	var flag overlayOriginKey
	return context.WithValue(ctx, flag, true)
}

// IsOverlayOrigin reports whether ctx was flagged with WithOverlayOrigin,
// i.e. whether the request reached the proxy via the overlay listener.
// Middlewares that only make sense for WAN traffic (geolocation, CrowdSec
// IP reputation) should short-circuit when it returns true. A missing
// flag, or a non-bool value stored under the key, yields false.
func IsOverlayOrigin(ctx context.Context) bool {
	flagged, isBool := ctx.Value(overlayOriginKey{}).(bool)
	return isBool && flagged
}

160
proxy/lifecycle.go Normal file
View File

@@ -0,0 +1,160 @@
package proxy
import (
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/acme"
)
// Config bundles every knob the proxy reads at construction time. It mirrors
// the public fields on Server so library callers don't have to learn the
// internal struct layout. Zero values mean "feature off" or "fall back to the
// internal default" depending on the field — see the per-field doc.
//
// The standalone binary continues to populate Server fields directly, so
// adding fields here must not change the zero-value behaviour of Server.
//
// NOTE(review): New copies every field onto Server verbatim and applies no
// defaults itself; the fallbacks described on individual fields (ID, Logger,
// Version, ACMEChallengeType, ...) are presumably applied later in the
// lifecycle — confirm where.
type Config struct {
// ListenAddr is the TCP address the main listener binds. Required.
ListenAddr string
// ID identifies this proxy instance to management. An empty value is
// replaced by a generated, timestamped default.
ID string
// Logger is the logrus logger used everywhere. A nil value falls back
// to log.StandardLogger().
Logger *log.Logger
// Version is the build version string reported to management. Empty
// becomes "dev".
Version string
// ProxyURL is the public address operators use to reach this proxy.
ProxyURL string
// ManagementAddress is the gRPC URL of the management server.
ManagementAddress string
// ProxyToken authenticates this proxy with the management server.
ProxyToken string
// CertificateDirectory is the directory holding TLS certificate
// material (static or ACME-provisioned).
CertificateDirectory string
// CertificateFile is the certificate filename within
// CertificateDirectory.
CertificateFile string
// CertificateKeyFile is the private key filename within
// CertificateDirectory.
CertificateKeyFile string
// GenerateACMECertificates toggles ACME certificate provisioning.
GenerateACMECertificates bool
// ACMEChallengeAddress is the listen address for HTTP-01 challenges.
ACMEChallengeAddress string
// ACMEDirectory is the ACME directory URL (Let's Encrypt by default).
ACMEDirectory string
// ACMEEABKID is the External Account Binding Key ID for CAs that
// require EAB (e.g. ZeroSSL).
ACMEEABKID string
// ACMEEABHMACKey is the External Account Binding HMAC key for CAs
// that require EAB.
ACMEEABHMACKey string
// ACMEChallengeType is the ACME challenge type ("tls-alpn-01" or
// "http-01"). Empty defaults to "tls-alpn-01".
ACMEChallengeType string
// CertLockMethod controls how ACME certificate locks are coordinated
// across replicas.
CertLockMethod acme.CertLockMethod
// WildcardCertDir is an optional directory containing static wildcard
// certificates that override ACME for matching domains.
WildcardCertDir string
// DebugEndpointEnabled toggles the debug HTTP endpoint.
DebugEndpointEnabled bool
// DebugEndpointAddress is the bind address for the debug endpoint.
DebugEndpointAddress string
// HealthAddr is the bind address for the health probe and metrics
// surface. Empty disables the health probe entirely (library callers
// can attach their own). New stores it in Server.HealthAddress.
HealthAddr string
// ForwardedProto overrides the X-Forwarded-Proto value sent to
// backends. Valid values: "auto", "http", "https".
ForwardedProto string
// TrustedProxies is a list of IP prefixes for trusted upstream
// proxies that may set forwarding headers.
TrustedProxies []netip.Prefix
// WireguardPort is the UDP port for the embedded NetBird tunnel.
// Zero asks the OS for a random port.
WireguardPort uint16
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
ProxyProtocol bool
// PreSharedKey is the WireGuard pre-shared key used between the
// proxy's embedded clients and peers.
PreSharedKey string
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services.
SupportsCustomPorts bool
// RequireSubdomain forces accounts to use a subdomain in front of
// the proxy's cluster domain.
RequireSubdomain bool
// Private flags this proxy as embedded in a netbird client and
// serving exclusively over the WireGuard tunnel. Also enables
// per-account inbound listeners on each embedded client's netstack.
Private bool
// MaxDialTimeout caps the per-service backend dial timeout.
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
// CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec.
CrowdSecAPIURL string
// CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables
// CrowdSec.
CrowdSecAPIKey string
}
// New translates cfg into a Server by copying each Config field onto its
// Server counterpart (Config.HealthAddr maps to Server.HealthAddress). It
// performs no I/O: no goroutines are spawned, no network connections are
// dialed, and no listeners are bound — call Start to bring the proxy up.
// Because only exported configuration fields are populated, the result
// matches what the standalone code path builds when it constructs Server
// directly.
func New(cfg Config) *Server {
	srv := new(Server)

	// Identity and management connectivity.
	srv.ListenAddr = cfg.ListenAddr
	srv.ID = cfg.ID
	srv.Logger = cfg.Logger
	srv.Version = cfg.Version
	srv.ProxyURL = cfg.ProxyURL
	srv.ManagementAddress = cfg.ManagementAddress
	srv.ProxyToken = cfg.ProxyToken

	// TLS material and ACME provisioning.
	srv.CertificateDirectory = cfg.CertificateDirectory
	srv.CertificateFile = cfg.CertificateFile
	srv.CertificateKeyFile = cfg.CertificateKeyFile
	srv.GenerateACMECertificates = cfg.GenerateACMECertificates
	srv.ACMEChallengeAddress = cfg.ACMEChallengeAddress
	srv.ACMEDirectory = cfg.ACMEDirectory
	srv.ACMEEABKID = cfg.ACMEEABKID
	srv.ACMEEABHMACKey = cfg.ACMEEABHMACKey
	srv.ACMEChallengeType = cfg.ACMEChallengeType
	srv.CertLockMethod = cfg.CertLockMethod
	srv.WildcardCertDir = cfg.WildcardCertDir

	// Debug and health endpoints.
	srv.DebugEndpointEnabled = cfg.DebugEndpointEnabled
	srv.DebugEndpointAddress = cfg.DebugEndpointAddress
	srv.HealthAddress = cfg.HealthAddr

	// Forwarding, tunnel and listener behaviour.
	srv.ForwardedProto = cfg.ForwardedProto
	srv.TrustedProxies = cfg.TrustedProxies
	srv.WireguardPort = cfg.WireguardPort
	srv.ProxyProtocol = cfg.ProxyProtocol
	srv.PreSharedKey = cfg.PreSharedKey
	srv.SupportsCustomPorts = cfg.SupportsCustomPorts
	srv.RequireSubdomain = cfg.RequireSubdomain
	srv.Private = cfg.Private

	// Timeout caps.
	srv.MaxDialTimeout = cfg.MaxDialTimeout
	srv.MaxSessionIdleTimeout = cfg.MaxSessionIdleTimeout

	// Geolocation and CrowdSec.
	srv.GeoDataDir = cfg.GeoDataDir
	srv.CrowdSecAPIURL = cfg.CrowdSecAPIURL
	srv.CrowdSecAPIKey = cfg.CrowdSecAPIKey

	return srv
}

View File

@@ -239,6 +239,10 @@ func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil
}
// ClusterSupportsPrivate is a test stub: it ignores both arguments and
// always returns nil.
// NOTE(review): nil presumably means "capability unknown" to callers —
// confirm against the interface this stub implements.
func (m *testProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
// CleanupStale is a test stub: it ignores both arguments, performs no
// cleanup, and always returns a nil error.
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
@@ -565,6 +569,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()),
nil,
mapping.GetPrivate(),
)
require.NoError(t, err)

View File

@@ -37,6 +37,8 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
@@ -114,9 +116,28 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
inbound *inboundManager
// Lifecycle state — populated by Start, consumed by Stop. The fields
// stay zero on a fresh Server until Start runs so direct struct
// construction (`&Server{...}`) used by tests still works.
runCancel context.CancelFunc
mgmtConn *grpc.ClientConn
runErr error
runErrCh chan struct{}
startMu sync.Mutex
started bool
stopOnce sync.Once
// Mostly used for debugging on management.
startTime time.Time
// ListenAddr is the address the main TCP listener binds. Populated by
// New from Config or by ListenAndServe from its addr argument.
ListenAddr string
ID string
Logger *log.Logger
Version string
@@ -177,6 +198,14 @@ type Server struct {
// in front of this proxy's cluster domain. When true, accounts cannot
// create services on the bare cluster domain.
RequireSubdomain bool
// Private flags this proxy as embedded in a netbird client and serving
// exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported
// upstream as a capability so dashboards can distinguish per-peer
// clusters from centralised ones, and turns on per-account inbound
// listeners on each embedded client's netstack: every account that
// registers a service exposes :80 + :443 inside its own WG tunnel,
// scoped to that account's services only.
Private bool
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
@@ -222,12 +251,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
req := &proto.SendStatusUpdateRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Status: status,
CertificateIssued: false,
})
}
if connected {
req.InboundListener = s.inboundListenerProto(accountID)
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, req)
return err
}
@@ -238,56 +271,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
AccountId: string(accountID),
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
CertificateIssued: true,
InboundListener: s.inboundListenerProto(accountID),
})
return err
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.initDefaults()
s.routerReady = make(chan struct{})
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
s.portRouters = make(map[uint16]*portRouter)
s.svcPorts = make(map[types.ServiceID][]uint16)
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
// inboundListenerProto resolves the per-account inbound listener state for
// the SendStatusUpdate payload. Returns nil when the inbound manager is not
// initialised (s.inbound is nil, i.e. Private is off) or the account has no
// live listener with a tunnel IP, so management treats the field as absent.
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
if s.inbound == nil {
return nil
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
info, ok := s.inbound.ListenerInfo(accountID)
if !ok || info.TunnelIP == "" {
return nil
}
return &proto.ProxyInboundListener{
TunnelIp: info.TunnelIP,
HttpsPort: uint32(info.HTTPSPort),
HttpPort: uint32(info.HTTPPort),
}
}
mgmtConn, err := s.dialManagement()
if err != nil {
// ListenAndServe is the standalone entrypoint. It binds the listener, runs
// the proxy until ctx is cancelled or a background goroutine fails, then
// drains and stops. Library callers should prefer New + Start + Stop and
// own their own shutdown signalling.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
s.ListenAddr = addr
if err := s.Start(ctx); err != nil {
return err
}
defer func() {
if err := mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}()
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
return s.waitAndStop(ctx)
}
// Start brings the proxy up: dials management, configures TLS, binds the
// main listener, and spawns the SNI router and HTTPS server goroutines. It
// returns once the listener is bound; background errors are surfaced
// through Stop's return value. Start may only be called once: a second call
// returns a "proxy already started" error.
func (s *Server) Start(ctx context.Context) error {
s.startMu.Lock()
if s.started {
s.startMu.Unlock()
return errors.New("proxy already started")
}
s.started = true
s.startMu.Unlock()
s.initLifecycleState()
if err := s.initMetrics(ctx); err != nil {
return err
}
if err := s.initManagementClient(); err != nil {
return err
}
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
s.runCancel = runCancel
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.initNetBirdClient()
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
@@ -300,34 +345,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
s.initReverseProxy()
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
if err := s.initGeoLookup(); err != nil {
return err
}
var startupOK bool
startupOK := false
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
if closeErr := s.geoRaw.Close(); closeErr != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", closeErr)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.startDebugEndpoint()
@@ -336,35 +372,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Build the handler chain from inside out.
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = s.accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
handler = s.hijackTracker.Middleware(handler)
handler := s.buildHandlerChain()
s.initPrivateInbound(handler, tlsConfig)
// Start a raw TCP listener; the SNI router peeks at ClientHello
// and routes to either the HTTP handler or a TCP relay.
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", addr)
ln, err := s.bindMainListener(ctx)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
return err
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
s.mainRouter.SetObserver(s.meter)
s.mainRouter.SetAccessLogger(s.accessLog)
close(s.routerReady)
// The HTTP server uses the chanListener fed by the SNI router.
s.https = &http.Server{
Addr: addr,
Addr: s.ListenAddr,
Handler: handler,
TLSConfig: tlsConfig,
ReadHeaderTimeout: httpReadHeaderTimeout,
@@ -374,35 +396,201 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
s.recordRunErr(fmt.Errorf("https server: %w", serveErr))
}
}()
routerErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting SNI router on %s", addr)
routerErr <- s.mainRouter.Serve(runCtx, ln)
s.Logger.Debugf("starting SNI router on %s", s.ListenAddr)
if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil {
s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr))
}
}()
return nil
}
// Stop shuts the server down: it runs gracefulShutdown, cancels the run
// context and closes the management connection, all on a helper goroutine.
// The caller waits for that goroutine or for ctx to be done, whichever
// happens first; when ctx wins, a warning is logged and the helper keeps
// running in the background.
//
// The shutdown sequence is guarded by stopOnce, so it executes at most once.
// If the server was never started the guarded function returns immediately
// (note that this still consumes the once-guard). Every call, including
// repeated ones, returns the background error stored by recordRunErr, or nil
// when none was recorded.
func (s *Server) Stop(ctx context.Context) error {
	s.stopOnce.Do(func() {
		s.startMu.Lock()
		wasStarted := s.started
		s.startMu.Unlock()
		if !wasStarted {
			return
		}

		finished := make(chan struct{})
		go func() {
			defer close(finished)

			s.gracefulShutdown()
			if s.runCancel != nil {
				s.runCancel()
			}
			if s.mgmtConn == nil {
				return
			}
			if closeErr := s.mgmtConn.Close(); closeErr != nil {
				s.Logger.Debugf("management connection close: %v", closeErr)
			}
		}()

		select {
		case <-ctx.Done():
			s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err())
		case <-finished:
		}
	})

	// runErr is written under startMu by recordRunErr, so read it under the
	// same lock.
	s.startMu.Lock()
	firstErr := s.runErr
	s.startMu.Unlock()
	return firstErr
}
// waitAndStop blocks until ctx is cancelled or a background goroutine
// reports a fatal error, then drains and stops. Used by ListenAndServe.
func (s *Server) waitAndStop(ctx context.Context) error {
select {
case err := <-httpsErr:
s.shutdownServices()
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("https server: %w", err)
}
return nil
case err := <-routerErr:
s.shutdownServices()
if err != nil {
return fmt.Errorf("SNI router: %w", err)
}
return nil
case <-ctx.Done():
s.gracefulShutdown()
return nil
case <-s.runErrCh:
}
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout)
defer cancel()
return s.Stop(stopCtx)
}
// recordRunErr stores the first fatal background error and closes runErrCh
// so that anything selecting on it (waitAndStop) unblocks. Subsequent errors
// are only logged at debug level so the first cause is preserved.
//
// A nil err is ignored: recording it would close runErrCh without setting
// runErr, and the next call would then close the channel a second time and
// panic.
func (s *Server) recordRunErr(err error) {
	if err == nil {
		return
	}

	s.startMu.Lock()
	defer s.startMu.Unlock()

	if s.runErr != nil {
		s.Logger.Debugf("background error after first failure: %v", err)
		return
	}
	s.runErr = err
	// runErrCh is nil for servers that never went through
	// initLifecycleState; closing a nil channel would panic.
	if s.runErrCh != nil {
		close(s.runErrCh)
	}
}
// initLifecycleState applies the field defaults and then allocates the
// channels and maps that the rest of the startup sequence writes into.
func (s *Server) initLifecycleState() {
	s.initDefaults()

	// Signalling channels; both are closed (never sent on) by their owners.
	s.runErrCh = make(chan struct{})
	s.routerReady = make(chan struct{})

	// Per-service and per-port bookkeeping.
	s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
	s.svcPorts = make(map[types.ServiceID][]uint16)
	s.portRouters = make(map[uint16]*portRouter)
	s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
}
// initMetrics creates a Prometheus exporter, attaches it to a new meter
// provider, and stores the resulting proxy metrics bundle in s.meter. The
// meter is named after the package path of Server.
func (s *Server) initMetrics(ctx context.Context) error {
	exporter, err := prometheus.New()
	if err != nil {
		return fmt.Errorf("create prometheus exporter: %w", err)
	}

	meterName := reflect.TypeOf(Server{}).PkgPath()
	provider := metric.NewMeterProvider(metric.WithReader(exporter))

	if s.meter, err = proxymetrics.New(ctx, provider.Meter(meterName)); err != nil {
		return fmt.Errorf("create metrics: %w", err)
	}
	return nil
}
// initManagementClient dials management via dialManagement and keeps both
// the raw connection (closed later by Stop) and a ProxyService client built
// on top of it. A dial error is returned unchanged.
func (s *Server) initManagementClient() error {
	mgmtConn, dialErr := s.dialManagement()
	if dialErr != nil {
		return dialErr
	}

	s.mgmtConn = mgmtConn
	s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
	return nil
}
// initNetBirdClient constructs the embedded NetBird client and stores it in
// s.netbird, then hooks the add-peer duration metric into it. The client is
// used for outbound RoundTripping and, when --private is set, for the
// per-account inbound listeners.
func (s *Server) initNetBirdClient() {
	clientCfg := roundtrip.ClientConfig{
		MgmtAddr:     s.ManagementAddress,
		WGPort:       s.WireguardPort,
		PreSharedKey: s.PreSharedKey,
		// With --private the embedded client serves per-account inbound
		// listeners and has to apply management's ACL, so BlockInbound
		// stays off and the engine creates the ACL manager. The standalone
		// proxy never accepts inbound on the embedded client, so it blocks.
		BlockInbound: !s.Private,
	}

	s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, clientCfg, s.Logger, s, s.mgmtClient)
	s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
}
// initReverseProxy creates the reverse proxy and stores it in s.proxy. The
// upstream transport is a MultiTransport wrapped by the meter's RoundTripper:
// targets that opted into direct_upstream go through the host's network
// stack (stdlib transport), all others go through the embedded NetBird
// client. The split lets direct_upstream targets resolve DNS with the proxy
// host's resolver rather than the tunnel's DNS.
func (s *Server) initReverseProxy() {
	meteredRT := s.meter.RoundTripper(roundtrip.NewMultiTransport(s.netbird, s.Logger))
	s.proxy = proxy.NewReverseProxy(meteredRT, s.ForwardedProto, s.TrustedProxies, s.Logger)
}
// initGeoLookup opens the geolocation lookup from s.GeoDataDir. The raw
// lookup is always stored in s.geoRaw; s.geo is only set when the lookup is
// non-nil, so it is never assigned a nil lookup value.
func (s *Server) initGeoLookup() error {
	lookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
	if err != nil {
		return fmt.Errorf("initialize geolocation: %w", err)
	}

	s.geoRaw = lookup
	if lookup == nil {
		return nil
	}
	s.geo = lookup
	return nil
}
// buildHandlerChain assembles the HTTP handler stack. Wrapping goes from the
// innermost handler (the reverse proxy) outwards, so the hijack tracker is
// the first middleware a request passes through and the proxy the last.
func (s *Server) buildHandlerChain() http.Handler {
	var chain http.Handler = s.proxy

	chain = s.auth.Protect(chain)
	chain = web.AssetHandler(chain)
	chain = s.accessLog.Middleware(chain)
	chain = s.meter.Middleware(chain)
	chain = s.hijackTracker.Middleware(chain)

	return chain
}
// bindMainListener opens the main TCP listener on s.ListenAddr, wraps it
// with PROXY protocol handling when s.ProxyProtocol is set, records the
// bound port in s.mainPort and logs the result. The listener address is
// asserted to be a *net.TCPAddr.
func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) {
	var listenCfg net.ListenConfig

	listener, err := listenCfg.Listen(ctx, "tcp", s.ListenAddr)
	if err != nil {
		return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err)
	}
	if s.ProxyProtocol {
		listener = s.wrapProxyProtocol(listener)
	}

	boundAddr := listener.Addr()
	s.mainPort = uint16(boundAddr.(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid

	fields := log.Fields{
		"requested_addr": s.ListenAddr,
		"bound_addr":     boundAddr.String(),
		"private":        s.Private,
		"proxy_protocol": s.ProxyProtocol,
	}
	s.Logger.WithFields(fields).Info("proxy main listener bound")

	return listener, nil
}
// initDefaults sets fallback values for optional Server fields.
@@ -434,6 +622,9 @@ func (s *Server) startDebugEndpoint() {
if s.acme != nil {
debugHandler.SetCertStatus(s.acme)
}
if s.inbound != nil {
debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound})
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
@@ -447,16 +638,18 @@ func (s *Server) startDebugEndpoint() {
}()
}
// startHealthServer launches the health probe and metrics server.
// startHealthServer launches the health probe and metrics server. Empty
// HealthAddress disables the probe entirely (intended for library callers
// that want to manage their own health surface).
func (s *Server) startHealthServer() error {
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = defaultHealthAddr
if s.HealthAddress == "" {
s.Logger.Debug("health probe disabled (empty HealthAddress)")
return nil
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", healthAddr)
s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", s.HealthAddress)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
@@ -507,8 +700,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr
}
const (
defaultHealthAddr = "localhost:8080"
defaultDebugAddr = "localhost:8444"
// defaultDebugAddr is the localhost-bound fallback for the debug endpoint
// when DebugEndpointAddress is empty.
defaultDebugAddr = "localhost:8444"
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
// header after accepting a connection.
@@ -661,8 +855,10 @@ func (s *Server) gracefulShutdown() {
defer drainCancel()
s.Logger.Info("draining in-flight connections")
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
if s.https != nil {
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
}
}
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
@@ -809,6 +1005,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu
return client.DialContext, nil
}
// initPrivateInbound creates the inbound manager and registers its
// ready/stop callbacks on the embedded NetBird client. It does nothing
// unless s.Private is set, leaving the standalone proxy untouched.
func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) {
	if !s.Private {
		return
	}

	s.inbound = newInboundManager(s.Logger, handler, tlsConfig)
	onReady, onStop := s.inbound.onClientReady, s.inbound.onClientStop
	s.netbird.SetClientLifecycle(onReady, onStop)

	s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)")
}
// notifyError reports a resource error back to management so it can be
// surfaced to the user (e.g. port bind failure, dialer resolution error).
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
@@ -942,7 +1150,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
}
// syncSupported tracks whether management supports SyncMappings.
// Starts true; set to false on first Unimplemented error.
// Starts true; set to false on the first Unimplemented error so
// subsequent retries skip straight to GetMappingUpdate.
syncSupported := true
initialSyncDone := false
@@ -992,10 +1201,15 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
supportsCrowdSec := s.crowdsecRegistry.Available()
privateCapability := s.Private
// Always true: this build enforces ProxyMapping.private via the auth middleware.
supportsPrivateService := true
return &proto.ProxyCapabilities{
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
Private: &privateCapability,
SupportsPrivateService: &supportsPrivateService,
}
}
@@ -1027,7 +1241,6 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
return fmt.Errorf("create sync stream: %w", err)
}
// Send init message.
if err := stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
@@ -1058,6 +1271,10 @@ func isSyncUnimplemented(err error) bool {
return ok && st.Code() == codes.Unimplemented
}
// handleSyncMappingsStream consumes batches from a bidirectional SyncMappings
// stream, sending an ack after each batch is fully processed. Management waits
// for the ack before sending the next batch, providing application-level
// back-pressure.
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
select {
case <-s.routerReady:
@@ -1095,39 +1312,10 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
}
}
// handleMappingStream consumes mapping updates from a GetMappingUpdate
// stream. It first waits for s.routerReady (or ctx cancellation), then
// loops: each received message's mappings are applied via processMappings
// and the batch is handed to the snapshot tracker.
//
// Returns nil when Recv reports io.EOF, ctx.Err() when ctx is cancelled,
// and a wrapped error for any other Recv failure.
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
// Do not process mappings before the router-ready signal.
select {
case <-s.routerReady:
case <-ctx.Done():
return ctx.Err()
}
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
// Non-blocking cancellation check above; Recv itself blocks until the
// next message or a stream error.
msg, err := mappingClient.Recv()
switch {
case errors.Is(err, io.EOF):
return nil
case err != nil:
return fmt.Errorf("receive msg: %w", err)
}
// batchStart is passed to the tracker together with the batch.
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
}
}
}
// snapshotTracker accumulates service IDs during the initial snapshot and
// finalises sync state when the complete flag arrives.
// finalises sync state when the complete flag arrives. Used by both
// handleMappingStream and handleSyncMappingsStream so metric emission and
// reconciliation behave identically on either RPC.
type snapshotTracker struct {
done *bool
connectTime time.Time
@@ -1171,6 +1359,37 @@ func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings [
s.Logger.Info("Initial mapping sync complete")
}
// handleMappingStream consumes mapping updates from a GetMappingUpdate
// stream. It waits for s.routerReady (or ctx cancellation) before reading,
// then applies every received batch with processMappings and reports it to
// the snapshot tracker.
//
// Returns nil when Recv reports io.EOF, ctx.Err() when ctx is cancelled,
// and a wrapped error for any other Recv failure.
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-s.routerReady:
	}

	tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
	for {
		// Check for cancellation before blocking in Recv.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		update, recvErr := mappingClient.Recv()
		if errors.Is(recvErr, io.EOF) {
			return nil
		}
		if recvErr != nil {
			return fmt.Errorf("receive msg: %w", recvErr)
		}

		startedAt := time.Now()
		batch := update.GetMapping()

		s.Logger.Debug("Received mapping update, starting processing")
		s.processMappings(ctx, batch)
		s.Logger.Debug("Processing mapping update completed")

		tracker.recordBatch(ctx, s, batch, update.GetInitialSyncComplete(), startedAt)
	}
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
@@ -1192,17 +1411,58 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
s.ensurePeers(ctx, mappings)
// mappingJSONMarshal is the protojson configuration used when dumping a
// mapping into a debug log entry: single-line output, proto field names, and
// unpopulated (zero-value) fields included.
var mappingJSONMarshal = protojson.MarshalOptions{
	UseProtoNames:   true,
	EmitUnpopulated: true,
	Multiline:       false,
}
// redactMappingForLog returns a deep copy of m in which credential-bearing
// fields are replaced with a placeholder, so debug logs never carry secrets.
// The input message is not modified.
//
// Redacted fields:
//   - auth_token (only when non-empty),
//   - the hashed value of every header-auth entry (only when non-empty),
//   - every custom upstream header value (header names are kept).
//
// A nil m yields a nil result instead of panicking on the field accesses
// below.
func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
	const placeholder = "[REDACTED]"

	c := goproto.Clone(m).(*proto.ProxyMapping)
	if c == nil {
		// Cloning a nil message yields a nil message of the same type.
		return c
	}

	if c.GetAuthToken() != "" {
		c.AuthToken = placeholder
	}

	// The getters are nil-safe, so a missing Auth section ranges over nothing.
	for _, h := range c.GetAuth().GetHeaderAuths() {
		if h.GetHashedValue() != "" {
			h.HashedValue = placeholder
		}
	}

	for _, p := range c.GetPath() {
		opts := p.GetOptions()
		if opts == nil || len(opts.CustomHeaders) == 0 {
			continue
		}
		// Build a fresh map so header names survive but no value does.
		redacted := make(map[string]string, len(opts.CustomHeaders))
		for k := range opts.CustomHeaders {
			redacted[k] = placeholder
		}
		opts.CustomHeaders = redacted
	}
	return c
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"mode": mapping.GetMode(),
"port": mapping.GetListenPort(),
"id": mapping.GetId(),
}).Debug("Processing mapping update")
if debug {
raw, err := mappingJSONMarshal.Marshal(redactMappingForLog(mapping))
if err != nil {
raw = []byte(fmt.Sprintf("<marshal error: %v>", err))
}
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"id": mapping.GetId(),
"mapping": string(raw),
}).Debug("Processing mapping update")
}
switch mapping.GetType() {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
@@ -1228,60 +1488,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
}
}
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
// CREATED mappings. Peers for different accounts are created concurrently,
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
//
// Only the first CREATED mapping seen for an account is used as that
// account's representative, and accounts that already have a client are
// skipped. When at most one account needs a peer the function returns
// without creating anything here. AddPeer failures are logged at warn level
// and are not returned to the caller. The call blocks until every started
// AddPeer has returned.
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
// Collect one representative mapping per account that needs a new peer.
type peerReq struct {
accountID types.AccountID
svcKey roundtrip.ServiceKey
authToken string
svcID types.ServiceID
}
// seen de-duplicates accounts within this batch, including accounts that
// turn out to already have a client.
seen := make(map[types.AccountID]struct{})
var reqs []peerReq
for _, m := range mappings {
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
continue
}
accountID := types.AccountID(m.GetAccountId())
if _, ok := seen[accountID]; ok {
continue
}
seen[accountID] = struct{}{}
if s.netbird.HasClient(accountID) {
continue
}
reqs = append(reqs, peerReq{
accountID: accountID,
svcKey: s.serviceKeyForMapping(m),
authToken: m.GetAuthToken(),
svcID: types.ServiceID(m.GetId()),
})
}
// Nothing to parallelise with zero or one pending account.
if len(reqs) <= 1 {
return
}
var wg sync.WaitGroup
wg.Add(len(reqs))
for _, r := range reqs {
// NOTE(review): the goroutine captures r directly, which relies on
// per-iteration loop variables (Go 1.22+) — confirm the module's Go version.
go func() {
defer wg.Done()
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": r.accountID,
"service_id": r.svcID,
"error": err,
}).Warn("failed to pre-create peer for account")
}
}()
}
wg.Wait()
}
// addMapping registers a service mapping and starts the appropriate relay or routes.
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
accountID := types.AccountID(mapping.GetAccountId())
@@ -1353,12 +1559,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
})
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
if s.inbound != nil {
s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
}
if err := s.updateMapping(ctx, mapping); err != nil {
return fmt.Errorf("update mapping for domain %q: %w", d, err)
}
@@ -1718,7 +1928,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1774,6 +1984,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
}
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
if s.inbound != nil {
s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID)
}
}
// Extract and delete tracked custom-port entries atomically.
@@ -1861,6 +2074,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
pt.DirectUpstream = opts.GetDirectUpstream()
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt

View File

@@ -1,9 +1,17 @@
package proxy
import (
"context"
"errors"
"io"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestDebugEndpointDisabledByDefault(t *testing.T) {
@@ -46,3 +54,151 @@ func TestDebugEndpointAddr(t *testing.T) {
})
}
}
// quietLifecycleLogger returns a logger that discards its output and is set
// to panic level, so lifecycle tests do not clutter the test output.
func quietLifecycleLogger() *log.Logger {
	logger := log.New()
	logger.SetLevel(log.PanicLevel)
	logger.SetOutput(io.Discard)
	return logger
}
// TestStopBeforeStartIsNoOp checks that Stop on a server that was never
// started returns no error, both on the first and on a repeated call.
func TestStopBeforeStartIsNoOp(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	srv := New(Config{Logger: quietLifecycleLogger()})

	assert.NoError(t, srv.Stop(ctx), "Stop on an unstarted server must succeed without error")
	assert.NoError(t, srv.Stop(ctx), "Stop must remain idempotent across repeated calls")
}
func TestStartFailsWithoutManagement(t *testing.T) {
srv := New(Config{
Logger: quietLifecycleLogger(),
ListenAddr: "127.0.0.1:0",
ManagementAddress: "://broken-url",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := srv.Start(ctx)
require.Error(t, err, "Start must surface management dial failures")
assert.True(t, srv.started, "started flag is set before any dial attempt so a second Start fails fast")
err = srv.Start(ctx)
require.Error(t, err, "second Start must reject")
assert.Contains(t, err.Error(), "already started", "error must explain why the call was rejected")
}
// TestStopIsIdempotent checks that a started server with a recorded
// background error returns that same error from every Stop call.
func TestStopIsIdempotent(t *testing.T) {
	srv := &Server{
		Logger:    quietLifecycleLogger(),
		started:   true,
		runErrCh:  make(chan struct{}),
		runCancel: func() {},
	}
	srv.recordRunErr(errors.New("synthetic"))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Each entry describes one Stop call: the message for the "must fail"
	// check and the message for the "carries the recorded error" check.
	calls := []struct {
		mustErrMsg  string
		containsMsg string
	}{
		{
			mustErrMsg:  "Stop must surface the recorded background error",
			containsMsg: "error must round-trip recordRunErr's value",
		},
		{
			mustErrMsg:  "second Stop must still report the same error",
			containsMsg: "idempotent Stop must return the cached error",
		},
	}
	for _, call := range calls {
		stopErr := srv.Stop(ctx)
		require.Error(t, stopErr, call.mustErrMsg)
		assert.Contains(t, stopErr.Error(), "synthetic", call.containsMsg)
	}
}
// TestRecordRunErrPreservesFirstFailure checks that recordRunErr keeps the
// first error it is given and closes runErrCh.
func TestRecordRunErrPreservesFirstFailure(t *testing.T) {
	srv := &Server{
		Logger:   quietLifecycleLogger(),
		runErrCh: make(chan struct{}),
	}

	for _, msg := range []string{"first", "second"} {
		srv.recordRunErr(errors.New(msg))
	}

	require.Error(t, srv.runErr, "first failure must be retained")
	assert.Contains(t, srv.runErr.Error(), "first", "second call must not overwrite the cached error")

	// A closed channel is immediately readable; anything else means the
	// signal was never sent.
	select {
	case <-srv.runErrCh:
	default:
		t.Fatal("recordRunErr must close runErrCh so waitAndStop unblocks")
	}
}
// TestStopSkipsShutdownWhenNeverStarted checks that Stop returns no error
// when called on an unstarted server with an already-cancelled context.
func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) {
	srv := New(Config{Logger: quietLifecycleLogger()})

	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	stopErr := srv.Stop(cancelledCtx)
	assert.NoError(t, stopErr, "Stop on an unstarted server should not block on the cancelled ctx")
}
// TestRedactMappingForLog_ScrubsSensitiveFields checks that
// redactMappingForLog replaces the auth token, header-auth hashes and custom
// header values in its result, keeps non-sensitive fields, and leaves the
// input message untouched.
func TestRedactMappingForLog_ScrubsSensitiveFields(t *testing.T) {
	authSection := &proto.Authentication{
		SessionKey: "pubkey-not-secret",
		HeaderAuths: []*proto.HeaderAuth{
			{Header: "Authorization", HashedValue: "argon2-hash-1"},
			{Header: "X-Api-Key", HashedValue: "argon2-hash-2"},
		},
	}
	apiPath := &proto.PathMapping{
		Path:   "/api",
		Target: "10.0.0.1:8080",
		Options: &proto.PathTargetOptions{
			CustomHeaders: map[string]string{
				"Authorization": "Bearer upstream-token",
				"X-Tenant":      "acme",
			},
		},
	}
	input := &proto.ProxyMapping{
		Id:        "svc-1",
		Domain:    "example.com",
		AuthToken: "super-secret-token",
		Auth:      authSection,
		Path:      []*proto.PathMapping{apiPath},
	}

	got := redactMappingForLog(input)

	// The input must be left exactly as it was built.
	assert.Equal(t, "super-secret-token", input.AuthToken, "original must not be mutated")
	assert.Equal(t, "argon2-hash-1", input.Auth.HeaderAuths[0].HashedValue, "original header hash must not be mutated")
	assert.Equal(t, "Bearer upstream-token", input.Path[0].Options.CustomHeaders["Authorization"], "original custom header must not be mutated")

	// Credentials in the copy are replaced by the placeholder.
	assert.Equal(t, "[REDACTED]", got.AuthToken, "auth_token must be redacted")

	require.Len(t, got.Auth.HeaderAuths, 2, "header auths must be preserved in count")
	gotHeaderAuths := got.Auth.HeaderAuths
	assert.Equal(t, "Authorization", gotHeaderAuths[0].Header, "header name must be preserved")
	assert.Equal(t, "[REDACTED]", gotHeaderAuths[0].HashedValue, "hashed_value must be redacted")
	assert.Equal(t, "[REDACTED]", gotHeaderAuths[1].HashedValue, "hashed_value must be redacted for every header auth")
	assert.Equal(t, "pubkey-not-secret", got.Auth.SessionKey, "session_key (public) must be preserved")

	gotHeaders := got.Path[0].Options.CustomHeaders
	require.Len(t, gotHeaders, 2, "custom header keys must be preserved")
	assert.Equal(t, "[REDACTED]", gotHeaders["Authorization"], "custom header values must be redacted")
	assert.Equal(t, "[REDACTED]", gotHeaders["X-Tenant"], "every custom header value must be redacted")

	// Everything else is copied through unchanged.
	assert.Equal(t, "svc-1", got.Id, "non-sensitive fields must round-trip")
	assert.Equal(t, "example.com", got.Domain, "non-sensitive fields must round-trip")
}
func TestRedactMappingForLog_HandlesEmptyOrNilFields(t *testing.T) {
empty := &proto.ProxyMapping{Id: "svc-empty"}
redacted := redactMappingForLog(empty)
assert.Equal(t, "", redacted.AuthToken, "empty auth_token must remain empty (no placeholder)")
assert.Nil(t, redacted.Auth, "nil Auth must remain nil")
assert.Empty(t, redacted.Path, "empty Path must remain empty")
}