feat(private-service): expose NetBird-only services over tunnel peers

Adds a new "private" service mode for the reverse proxy: services
reachable exclusively over the embedded WireGuard tunnel, gated by
per-peer group membership instead of operator auth schemes.

Wire contract
- ProxyMapping.private (field 13): the proxy MUST call
  ValidateTunnelPeer and fail closed; operator schemes are bypassed.
- ProxyCapabilities.private (4) + supports_private_service (5):
  capability gate. Management never streams private mappings to
  proxies that don't claim the capability; the broadcast path applies
  the same filter via filterMappingsForProxy.
- ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer,
  checks the peer's groups against service.AccessGroups, and mints
  a session JWT on success. checkPeerGroupAccess fails closed when
  a private service has empty AccessGroups.
- ValidateSession/ValidateTunnelPeer responses now carry
  peer_group_ids + peer_group_names so the proxy can authorise
  policy-aware middlewares without an extra management round-trip.
- ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account
  inbound listener state surfaced to dashboards.
- PathTargetOptions.direct_upstream (11): bypass the embedded NetBird
  client and dial the target via the proxy host's network stack for
  upstreams reachable without WireGuard.

Data model
- Service.Private (bool) + Service.AccessGroups ([]string, JSON-
  serialised). Validate() rejects bearer auth on private services.
  Copy() deep-copies AccessGroups. pgx getServices loads the columns.
- DomainConfig.Private threaded into the proxy auth middleware.
  Request handler routes private services through forwardWithTunnelPeer
  and returns 403 on validation failure.
- Account-level SynthesizePrivateServiceZones (synthetic DNS) and
  injectPrivateServicePolicies (synthetic ACL) gate on
  len(svc.AccessGroups) > 0.

Proxy
- /netbird proxy --private (embedded mode) flag; Config.Private in
  proxy/lifecycle.go.
- Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS
  on the embedded NetBird client's WireGuard tunnel netstack.
- proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache
  with single-flight de-duplication and per-account eviction.
- Local peerstore short-circuit: when the inbound IP isn't in the
  account roster, deny fast without an RPC.
- proxy/server.go reports SupportsPrivateService=true and redacts the
  full ProxyMapping JSON from info logs (auth_token + header-auth
  hashed values now only at debug level).

Identity forwarding
- ValidateSessionJWT returns user_id, email, method, groups,
  group_names. sessionkey.Claims carries Email + Groups + GroupNames
  so the proxy can stamp identity onto upstream requests without an
  extra management round-trip on every cookie-bearing request.
- CapturedData carries userEmail / userGroups / userGroupNames; the
  proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the
  authenticated identity (strips client-supplied values first to
  prevent spoofing).
- AccessLog.UserGroups: access-log enrichment captures the user's
  group memberships at write time so the dashboard can render group
  context without reverse-resolving stale memberships.

OpenAPI/dashboard surface
- ReverseProxyService gains private + access_groups; ReverseProxyCluster
  gains private + supports_private. ReverseProxyTarget target_type
  enum gains "cluster". ServiceTargetOptions gains direct_upstream.
  ProxyAccessLog gains user_groups.
This commit is contained in:
mlsmaycon
2026-05-20 21:39:22 +02:00
parent 37052fd5bc
commit 167ee08e14
72 changed files with 6584 additions and 2586 deletions

View File

@@ -45,10 +45,14 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string {
}
}
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
// ValidateSessionJWT validates a session JWT and returns the user ID, the
// user's email (when carried), the authentication method, any embedded
// group memberships, and the parallel group display names. email,
// groups, and groupNames may be empty for tokens minted before those
// claims were introduced. groupNames pairs positionally with groups.
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) {
if publicKey == nil {
return "", "", fmt.Errorf("no public key configured for domain")
return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain")
}
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
@@ -58,20 +62,46 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey)
return publicKey, nil
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
if err != nil {
return "", "", fmt.Errorf("parse token: %w", err)
return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return "", "", fmt.Errorf("invalid token claims")
return "", "", "", nil, nil, fmt.Errorf("invalid token claims")
}
sub, _ := claims.GetSubject()
if sub == "" {
return "", "", fmt.Errorf("missing subject claim")
return "", "", "", nil, nil, fmt.Errorf("missing subject claim")
}
methodClaim, _ := claims["method"].(string)
emailClaim, _ := claims["email"].(string)
groups = extractGroupsClaim(claims["groups"])
groupNames = extractGroupsClaim(claims["group_names"])
return sub, methodClaim, nil
return sub, emailClaim, methodClaim, groups, groupNames, nil
}
// extractGroupsClaim decodes the "groups" claim into a string slice. The JWT
// library decodes JSON arrays as []interface{}, so we coerce element-wise
// and skip non-string entries silently.
func extractGroupsClaim(claim interface{}) []string {
raw, ok := claim.([]interface{})
if !ok {
return nil
}
if len(raw) == 0 {
return nil
}
groups := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
groups = append(groups, s)
}
}
if len(groups) == 0 {
return nil
}
return groups
}

View File

@@ -63,6 +63,7 @@ var (
preSharedKey string
supportsCustomPorts bool
requireSubdomain bool
private bool
geoDataDir string
crowdsecAPIURL string
crowdsecAPIKey string
@@ -105,6 +106,12 @@ func init() {
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough")
rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain")
// --private is internal: set by the embedded `netbird proxy` subcommand
// via NB_PROXY_PRIVATE so management can distinguish per-peer / private
// clusters from centralised ones. Hidden so the standalone CLI doesn't
// surface it as an operator-facing toggle.
rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Mark this proxy as embedded/private (internal flag)")
_ = rootCmd.Flags().MarkHidden("private")
rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)")
rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)")
rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)")
@@ -119,6 +126,16 @@ func Execute() {
}
}
// NewRootCmd returns the proxy server cobra command for embedding under
// another binary (for example as the "proxy" subcommand of the netbird
// client). The returned command shares its flag set, RunE, and any
// previously registered subcommands (e.g. "debug") with the standalone
// binary, so flag and env-var contracts stay identical across both
// invocations.
func NewRootCmd() *cobra.Command {
return rootCmd
}
// SetVersionInfo sets version information for the CLI.
func SetVersionInfo(version, commit, buildDate, goVersion string) {
Version = version
@@ -161,7 +178,8 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.Server{
srv := proxy.New(proxy.Config{
ListenAddr: addr,
Logger: logger,
Version: Version,
ManagementAddress: mgmtAddr,
@@ -178,7 +196,7 @@ func runServer(cmd *cobra.Command, args []string) error {
ACMEChallengeType: acmeChallengeType,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddress: healthAddr,
HealthAddr: healthAddr,
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
@@ -188,12 +206,13 @@ func runServer(cmd *cobra.Command, args []string) error {
PreSharedKey: preSharedKey,
SupportsCustomPorts: supportsCustomPorts,
RequireSubdomain: requireSubdomain,
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,
}
})
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()

537
proxy/inbound.go Normal file
View File

@@ -0,0 +1,537 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
stdlog "log"
"net"
"net/http"
"net/netip"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// httpInboundReadHeaderTimeout matches the host-listener read header timeout
// so per-account http.Servers don't leak idle connections.
const httpInboundReadHeaderTimeout = 30 * time.Second
// httpInboundIdleTimeout caps idle keep-alives on per-account inbound HTTP
// servers; matches the host listener.
const httpInboundIdleTimeout = 90 * time.Second
// inboundShutdownTimeout caps how long a per-account http.Server gets to
// drain in-flight requests during teardown.
const inboundShutdownTimeout = 5 * time.Second
// privateInboundPortHTTPS is the WG-side TLS port. Each account's
// embedded netstack binds independently, so a fixed port is fine.
const privateInboundPortHTTPS = 443
// privateInboundPortHTTP is the WG-side plain-HTTP port.
const privateInboundPortHTTP = 80
// inboundManager wires per-account inbound listeners into the proxy
// pipeline when --private-inbound is enabled. When disabled the manager
// is nil and every method on *Server that touches it short-circuits.
type inboundManager struct {
logger *log.Logger
handler http.Handler
tlsConfig *tls.Config
// muxLock guards entries and pendingRoutes.
muxLock sync.Mutex
entries map[types.AccountID]*inboundEntry
pendingRoutes map[types.AccountID][]pendingInboundRoute
}
// inboundEntry owns the listeners, router and HTTP servers for a single
// account's embedded netstack.
type inboundEntry struct {
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
cancel context.CancelFunc
wg sync.WaitGroup
}
// pendingInboundRoute holds a route that arrived before the account's
// listener finished starting.
type pendingInboundRoute struct {
host nbtcp.SNIHost
route nbtcp.Route
}
// newInboundManager constructs a manager bound to the proxy's HTTP
// handler chain and TLS config.
func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager {
return &inboundManager{
logger: logger,
handler: handler,
tlsConfig: tlsConfig,
entries: make(map[types.AccountID]*inboundEntry),
pendingRoutes: make(map[types.AccountID][]pendingInboundRoute),
}
}
// onClientReady is registered with NetBird.SetClientLifecycle so the
// listener pair comes up exactly when the embedded client reports ready.
// The returned value is opaque to the roundtrip package; it is handed
// back verbatim to onClientStop on teardown.
func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any {
if m == nil {
return nil
}
entry, err := m.bringUp(ctx, accountID, client)
if err != nil {
m.logger.WithField("account_id", accountID).WithError(err).Warn("failed to start per-account inbound listener; continuing without inbound")
return nil
}
m.flushPending(accountID, entry)
m.logger.WithFields(log.Fields{
"account_id": accountID,
"https": entry.tlsListener.Addr().String(),
"http": entry.plainListener.Addr().String(),
}).Info("per-account inbound listeners up")
return entry
}
// onClientStop tears down a per-account listener bundle. State is the
// opaque value previously returned by onClientReady.
func (m *inboundManager) onClientStop(accountID types.AccountID, state any) {
if m == nil {
return
}
entry, ok := state.(*inboundEntry)
if !ok || entry == nil {
return
}
m.tearDown(accountID, entry)
}
// bringUp opens both listeners on the account's netstack, builds the
// router, and starts the parallel HTTP servers.
func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) {
tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS))
if err != nil {
return nil, fmt.Errorf("listen tls on netstack: %w", err)
}
plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP))
if err != nil {
_ = tlsListener.Close()
return nil, fmt.Errorf("listen plain on netstack: %w", err)
}
router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr()))
scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client))
httpsServer := &http.Server{
Handler: scopedHandler,
TLSConfig: m.tlsConfig,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
}
httpServer := &http.Server{
Handler: scopedHandler,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
}
runCtx, cancel := context.WithCancel(ctx)
entry := &inboundEntry{
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
cancel: cancel,
}
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := router.Serve(runCtx, tlsListener); err != nil {
m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) {
m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err)
}
}()
entry.wg.Add(1)
go func() {
defer entry.wg.Done()
feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID)
}()
m.muxLock.Lock()
m.entries[accountID] = entry
m.muxLock.Unlock()
return entry, nil
}
// tearDown shuts every goroutine down and closes the netstack listeners.
func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) {
m.muxLock.Lock()
if m.entries[accountID] == entry {
delete(m.entries, accountID)
delete(m.pendingRoutes, accountID)
}
m.muxLock.Unlock()
entry.cancel()
shutdownCtx, cancel := context.WithTimeout(context.Background(), inboundShutdownTimeout)
defer cancel()
if err := entry.httpsServer.Shutdown(shutdownCtx); err != nil {
m.logger.Debugf("per-account https shutdown: %v", err)
}
if err := entry.httpServer.Shutdown(shutdownCtx); err != nil {
m.logger.Debugf("per-account http shutdown: %v", err)
}
if err := entry.tlsListener.Close(); err != nil {
m.logger.Debugf("close per-account tls listener: %v", err)
}
if err := entry.plainListener.Close(); err != nil {
m.logger.Debugf("close per-account plain listener: %v", err)
}
entry.wg.Wait()
}
// AddRoute records an SNI/host route on the account's per-account router.
// Routes registered before the listener is up are queued and replayed
// once startup completes.
func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
if m == nil {
return
}
m.muxLock.Lock()
entry, ok := m.entries[accountID]
if !ok {
m.queuePendingLocked(accountID, host, route)
m.muxLock.Unlock()
return
}
router := entry.router
m.muxLock.Unlock()
router.AddRoute(host, route)
}
// RemoveRoute drops a previously registered route. Safe to call when the
// listener is not yet up; queued copies are pruned in that case.
func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
if m == nil {
return
}
m.muxLock.Lock()
m.dropPendingLocked(accountID, host, svcID)
entry, ok := m.entries[accountID]
if !ok {
m.muxLock.Unlock()
return
}
router := entry.router
m.muxLock.Unlock()
router.RemoveRoute(host, svcID)
}
// queuePendingLocked stores or upserts a pending route. Caller holds muxLock.
func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) {
queued := m.pendingRoutes[accountID]
for i, pr := range queued {
if pr.host == host && pr.route.ServiceID == route.ServiceID {
queued[i] = pendingInboundRoute{host: host, route: route}
m.pendingRoutes[accountID] = queued
return
}
}
m.pendingRoutes[accountID] = append(queued, pendingInboundRoute{host: host, route: route})
}
// dropPendingLocked removes any queued route matching host/svcID.
// Caller holds muxLock.
func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) {
queued, ok := m.pendingRoutes[accountID]
if !ok {
return
}
filtered := queued[:0]
for _, pr := range queued {
if pr.host == host && pr.route.ServiceID == svcID {
continue
}
filtered = append(filtered, pr)
}
if len(filtered) == 0 {
delete(m.pendingRoutes, accountID)
return
}
m.pendingRoutes[accountID] = filtered
}
// flushPending applies all queued routes to a freshly-up router.
func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) {
m.muxLock.Lock()
queued := m.pendingRoutes[accountID]
delete(m.pendingRoutes, accountID)
m.muxLock.Unlock()
for _, pr := range queued {
entry.router.AddRoute(pr.host, pr.route)
}
}
// HasInbound reports whether the manager has a live listener for the account.
// Used by tests.
func (m *inboundManager) HasInbound(accountID types.AccountID) bool {
if m == nil {
return false
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
_, ok := m.entries[accountID]
return ok
}
// PendingRouteCount reports the number of queued routes for the account.
// Used by tests.
func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int {
if m == nil {
return 0
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
return len(m.pendingRoutes[accountID])
}
// InboundListenerInfo describes the bound addresses of a single
// per-account inbound listener. Both addresses live on the embedded
// netstack of the account's WireGuard client and share the same tunnel IP.
type InboundListenerInfo struct {
TunnelIP string
HTTPSPort uint16
HTTPPort uint16
}
// ListenerInfo returns the inbound listener addresses for the given
// account, or ok=false when the account has no live listener. Used by
// the status-update RPC and the debug HTTP handler to surface inbound
// reachability to operators.
func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) {
if m == nil {
return InboundListenerInfo{}, false
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
entry, ok := m.entries[accountID]
if !ok || entry == nil {
return InboundListenerInfo{}, false
}
return listenerInfoFromEntry(entry), true
}
// Snapshot returns the inbound listener state for every account that has
// a live listener at call time. Empty when --private-inbound is off or
// no accounts have come up yet.
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
if m == nil {
return nil
}
m.muxLock.Lock()
defer m.muxLock.Unlock()
if len(m.entries) == 0 {
return nil
}
out := make(map[types.AccountID]InboundListenerInfo, len(m.entries))
for id, entry := range m.entries {
if entry == nil {
continue
}
out[id] = listenerInfoFromEntry(entry)
}
return out
}
// listenerInfoFromEntry extracts the tunnel IP and ports from a live
// per-account entry. Both listeners are bound on the same netstack so
// their host components match; we still pull the TLS host as the
// authoritative source.
func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo {
info := InboundListenerInfo{HTTPSPort: privateInboundPortHTTPS, HTTPPort: privateInboundPortHTTP}
if entry.tlsListener != nil {
host, port := splitHostPort(entry.tlsListener.Addr())
info.TunnelIP = host
if port != 0 {
info.HTTPSPort = port
}
}
if entry.plainListener != nil {
host, port := splitHostPort(entry.plainListener.Addr())
if info.TunnelIP == "" {
info.TunnelIP = host
}
if port != 0 {
info.HTTPPort = port
}
}
return info
}
// splitHostPort extracts host and port from a net.Addr, returning the
// zero values when the address is missing or malformed.
func splitHostPort(addr net.Addr) (string, uint16) {
if addr == nil {
return "", 0
}
host, portStr, err := net.SplitHostPort(addr.String())
if err != nil {
return "", 0
}
if portStr == "" {
return host, 0
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return host, 0
}
return host, uint16(port)
}
// feedRouterFromListener accepts on the plain-HTTP netstack listener and
// hands every connection to the account's router. The router peeks the
// first byte and dispatches to the plain-HTTP channel for non-TLS
// streams or the TLS channel for ClientHellos that arrive on :80.
func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) {
go func() {
<-ctx.Done()
_ = ln.Close()
}()
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
router.HandleConn(ctx, conn)
}
}
// accountDialResolver returns a DialResolver bound to a single account's
// embedded client. The router only ever serves traffic for that account
// so the supplied accountID is ignored at dial time.
func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver {
return func(_ types.AccountID) (types.DialContextFunc, error) {
return client.DialContext, nil
}
}
// accountTunnelLookup returns a TunnelLookupFunc backed by the embedded
// client's peerstore for a single account. Phase 3 uses the result to
// short-circuit ValidateTunnelPeer when the source IP is not in the
// account's roster and to seed the cached identity for known peers.
func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
if client == nil {
return nil
}
return func(ip netip.Addr) (auth.PeerIdentity, bool) {
pubKey, fqdn, ok := client.IdentityForIP(ip)
if !ok {
return auth.PeerIdentity{}, false
}
return auth.PeerIdentity{
PubKey: pubKey,
TunnelIP: ip,
FQDN: fqdn,
}, true
}
}
// withTunnelLookup returns an http.Handler that attaches the per-account
// peerstore lookup to every request's context before delegating to next.
// Calling on the host-level listener is a no-op because that path never
// installs this wrapper, so the existing behaviour stays byte-for-byte
// identical when --private-inbound is off or the request didn't arrive
// on a per-account listener.
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
if lookup == nil {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := auth.WithTunnelLookup(r.Context(), lookup)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// inboundDebugAdapter adapts *inboundManager to the debug.InboundProvider
// interface so the debug HTTP handler can render per-account inbound
// listener state without importing the proxy package.
type inboundDebugAdapter struct {
mgr *inboundManager
}
// InboundListeners returns a snapshot of the live per-account inbound
// listeners formatted for the debug surface.
func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo {
if a.mgr == nil {
return nil
}
snap := a.mgr.Snapshot()
if len(snap) == 0 {
return nil
}
out := make(map[types.AccountID]debug.InboundListenerInfo, len(snap))
for id, info := range snap {
out[id] = debug.InboundListenerInfo{
TunnelIP: info.TunnelIP,
HTTPSPort: info.HTTPSPort,
HTTPPort: info.HTTPPort,
}
}
return out
}
// newInboundErrorLog routes a per-account http.Server's stdlib error
// stream through logrus at warn level.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
return stdlog.New(logger.WithFields(log.Fields{
"inbound-http": scheme,
"account_id": accountID,
}).WriterLevel(log.WarnLevel), "", 0)
}

502
proxy/inbound_test.go Normal file
View File

@@ -0,0 +1,502 @@
package proxy
import (
"bufio"
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// bufioReader wraps the connection in a buffered reader so http.ReadResponse
// can parse the response line + headers off the wire.
func bufioReader(conn net.Conn) *bufio.Reader {
return bufio.NewReader(conn)
}
// quietLogger returns a logger that emits nothing — keeps test output tidy.
func quietLogger() *log.Logger {
logger := log.New()
logger.SetLevel(log.PanicLevel)
return logger
}
func TestInboundManager_RouteScopedToAccount(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountA := types.AccountID("acct-a")
accountB := types.AccountID("acct-b")
mgr.AddRoute(accountA, "shared.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountA, ServiceID: "svc-a", Domain: "shared.example"})
mgr.AddRoute(accountB, "other.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: "other.example"})
require.Equal(t, 1, mgr.PendingRouteCount(accountA), "account A should have one queued route")
require.Equal(t, 1, mgr.PendingRouteCount(accountB), "account B should have one queued route")
mgr.RemoveRoute(accountA, "shared.example", "svc-a")
mgr.RemoveRoute(accountB, "other.example", "svc-b")
assert.Equal(t, 0, mgr.PendingRouteCount(accountA), "queue should drain on remove")
assert.Equal(t, 0, mgr.PendingRouteCount(accountB), "queue should drain on remove")
}
func TestInboundManager_PendingThenFlush(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
host := nbtcp.SNIHost("example.test")
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"}
mgr.AddRoute(accountID, host, route)
require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up")
// Simulate listener up by registering a fake entry, then flushing.
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
entry := &inboundEntry{router: router}
mgr.muxLock.Lock()
mgr.entries[accountID] = entry
mgr.muxLock.Unlock()
mgr.flushPending(accountID, entry)
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush")
}
// fakeAddr is a stub net.Addr for tests that don't actually bind sockets.
type fakeAddr struct {
addr string
}
func (a *fakeAddr) Network() string { return "tcp" }
func (a *fakeAddr) String() string { return a.addr }
// fakeMgmtClient implements roundtrip.managementClient for tests.
type fakeMgmtClient struct{}
func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// TestServer_PrivateInbound_NotEnabled_NoManager confirms that with
// --private off the inbound manager is nil and the standalone proxy
// keeps its zero-overhead default path.
func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) {
s := &Server{Logger: quietLogger(), Private: false}
s.initPrivateInbound(http.NotFoundHandler(), nil)
assert.Nil(t, s.inbound, "manager should remain nil when --private is off")
}
// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that
// --private alone wires the manager into the NetBird transport, so
// AddPeer / RemovePeer drive the lifecycle.
func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
s := &Server{Logger: quietLogger(), Private: true}
// Construct a NetBird transport. We can't actually start the embedded
// client here (that needs a real management server), but we can
// confirm that the lifecycle callbacks are registered.
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
MgmtAddr: "http://invalid.test",
}, quietLogger(), nil, fakeMgmtClient{})
s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec
require.NotNil(t, s.inbound, "manager should be set when --private is on")
assert.NotNil(t, s.inbound.handler, "handler should be set on manager")
assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager")
}
// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that
// when the listener is already up, AddRoute writes straight to the
// router without queueing.
func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
mgr.muxLock.Lock()
mgr.entries[accountID] = &inboundEntry{router: router}
mgr.muxLock.Unlock()
host := nbtcp.SNIHost("ready.example")
mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)})
assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up")
}
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
// bit reported upstream tracks --private exclusively. The previous
// --private-inbound flag has been folded into --private.
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
tests := []struct {
name string
private bool
expected bool
}{
{"off", false, false},
{"on", true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Server{Private: tt.private}
assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private")
})
}
}
// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA verifies that a
// service registered for account B is invisible to a router serving
// account A. We exercise the path through real per-account routers.
func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountA := types.AccountID("acct-a")
accountB := types.AccountID("acct-b")
routerA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
routerB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"})
mgr.muxLock.Lock()
mgr.entries[accountA] = &inboundEntry{router: routerA}
mgr.entries[accountB] = &inboundEntry{router: routerB}
mgr.muxLock.Unlock()
host := nbtcp.SNIHost("shared.example")
mgr.AddRoute(accountB, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: string(host)})
// Account A's router should have no routes; account B's should have one.
// We check via IsEmpty — true means no routes and no fallback.
assert.True(t, routerA.IsEmpty(), "account A router must not see account B's mappings")
assert.False(t, routerB.IsEmpty(), "account B router should hold its own mapping")
}
// TestInboundEntry_ShutdownIdempotent ensures that tearDown can run twice
// without panicking — callers may invoke it from RemovePeer + StopAll.
func TestInboundEntry_ShutdownIdempotent(t *testing.T) {
t.Skip("teardown requires real netstack listeners; covered by integration tests")
}
// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account
// router pipeline against a loopback listener (proxy of a netstack
// listener for test purposes): a plain HTTP request lands on the plain
// http.Server and the inner handler observes a nil r.TLS, which is what
// auth.ResolveProto translates to "http" in the real pipeline.
func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) {
logger := quietLogger()
var captured atomic.Value
captured.Store("")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil {
captured.Store("http")
} else {
captured.Store("https")
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
hostListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "loopback listener bind must succeed")
defer hostListener.Close()
router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr()))
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
defer func() { _ = httpServer.Close() }()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
go func() { _ = router.Serve(ctx, hostListener) }()
conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second)
require.NoError(t, err, "plain HTTP dial must succeed")
defer conn.Close()
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"))
require.NoError(t, err, "write must succeed")
resp, err := http.ReadResponse(bufioReader(conn), nil)
require.NoError(t, err, "must read response")
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path")
}
// TestWithTunnelLookup_AttachesLookupToContext verifies that requests
// flowing through the per-account handler wrapper carry the peerstore
// lookup function. Phase 3's local-first deny path depends on this.
func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) {
expected := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"}
lookup := auth.TunnelLookupFunc(func(_ netip.Addr) (auth.PeerIdentity, bool) {
return expected, true
})
var observed auth.TunnelLookupFunc
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
observed = auth.TunnelLookupFromContext(r.Context())
})
handler := withTunnelLookup(inner, lookup)
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
handler.ServeHTTP(httptest.NewRecorder(), r)
require.NotNil(t, observed, "wrapper must inject the lookup into the request context")
got, ok := observed(netip.MustParseAddr("100.64.0.10"))
assert.True(t, ok, "lookup must round-trip through context")
assert.Equal(t, expected.FQDN, got.FQDN, "lookup must return the same identity it was constructed with")
}
// TestWithTunnelLookup_NilLookupIsNoop confirms the wrapper is a pure
// pass-through when no lookup is provided. Required for the host-level
// listener path to keep its byte-for-byte previous behaviour.
func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) {
var called bool
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
called = true
assert.Nil(t, auth.TunnelLookupFromContext(r.Context()), "host-level path must not see a lookup function")
})
handler := withTunnelLookup(inner, nil)
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
handler.ServeHTTP(httptest.NewRecorder(), r)
assert.True(t, called, "wrapper without lookup must still invoke next")
}
// fakeListener satisfies net.Listener for snapshot tests without binding
// a real socket on the netstack.
type fakeListener struct {
addr net.Addr
}
func (f *fakeListener) Accept() (net.Conn, error) { return nil, net.ErrClosed }
func (f *fakeListener) Close() error { return nil }
func (f *fakeListener) Addr() net.Addr { return f.addr }
// TestInboundManager_ListenerInfo confirms ListenerInfo and Snapshot
// surface the bound tunnel-IP and ports for live entries.
func TestInboundManager_ListenerInfo(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-info")
tlsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS}
plainAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP}
mgr.muxLock.Lock()
mgr.entries[accountID] = &inboundEntry{
tlsListener: &fakeListener{addr: tlsAddr},
plainListener: &fakeListener{addr: plainAddr},
}
mgr.muxLock.Unlock()
info, ok := mgr.ListenerInfo(accountID)
require.True(t, ok, "ListenerInfo must report ok for live entry")
assert.Equal(t, "100.64.0.5", info.TunnelIP, "tunnel IP must come from listener address")
assert.Equal(t, uint16(privateInboundPortHTTPS), info.HTTPSPort, "TLS port must match bound port")
assert.Equal(t, uint16(privateInboundPortHTTP), info.HTTPPort, "HTTP port must match bound port")
snap := mgr.Snapshot()
require.Len(t, snap, 1, "snapshot must contain exactly one entry")
assert.Equal(t, info, snap[accountID], "snapshot entry must equal direct lookup")
_, ok = mgr.ListenerInfo(types.AccountID("missing"))
assert.False(t, ok, "ListenerInfo must report ok=false for unknown accounts")
}
// TestInboundManager_NilManagerSafe ensures the observability accessors
// are safe to call when --private-inbound is off (nil manager).
func TestInboundManager_NilManagerSafe(t *testing.T) {
var mgr *inboundManager
_, ok := mgr.ListenerInfo("anything")
assert.False(t, ok, "nil manager must return ok=false")
assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot")
}
// TestInboundManager_ConcurrentAddRemove pounds AddRoute / RemoveRoute
// from multiple goroutines to expose any locking gaps.
func TestInboundManager_ConcurrentAddRemove(t *testing.T) {
mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil)
accountID := types.AccountID("acct-1")
const workers = 32
const iterations = 50
var wg sync.WaitGroup
wg.Add(workers)
for i := 0; i < workers; i++ {
go func(idx int) {
defer wg.Done()
host := nbtcp.SNIHost("example.test")
svc := types.ServiceID("svc")
route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"}
for j := 0; j < iterations; j++ {
mgr.AddRoute(accountID, host, route)
mgr.RemoveRoute(accountID, host, svc)
}
}(i)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(10 * time.Second):
t.Fatal("concurrent add/remove timed out")
}
}
// TestFeedRouterFromListener_DeliversConnectionToHandler validates the
// per-account inbound chain end-to-end with a loopback listener
// substituted for the embedded netstack: a TCP connection arriving at
// the plain listener flows through feedRouterFromListener, the router's
// peek-and-dispatch, the wrapped HTTP server, and reaches the user
// handler. If the embedded netstack is delivering connections at all,
// this is the path they take. Failures localise to wiring bugs in the
// proxy, not the netstack.
func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) {
logger := quietLogger()
hits := make(chan string, 1)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits <- r.Host
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("served"))
})
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "plain loopback bind must succeed")
t.Cleanup(func() { _ = plainLn.Close() })
router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr()))
httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second}
t.Cleanup(func() { _ = httpServer.Close() })
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }()
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1"))
conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second)
require.NoError(t, err, "must connect to the plain listener")
t.Cleanup(func() { _ = conn.Close() })
_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n"))
require.NoError(t, err, "request write must succeed")
resp, err := http.ReadResponse(bufioReader(conn), nil)
require.NoError(t, err, "must read response from server")
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached")
select {
case host := <-hits:
assert.Equal(t, "app.example", host, "handler must observe the request Host")
case <-time.After(2 * time.Second):
t.Fatal("handler was not invoked — connection did not flow through router → http server")
}
}
// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a
// TLS ClientHello arriving on the plain listener is detected by the
// router peek and re-dispatched to the TLS channel — the cross-channel
// fallback the inbound stack relies on for HTTPS-on-:80 testing.
func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) {
logger := quietLogger()
hits := make(chan string, 1)
tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits <- r.Host
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("served-tls"))
})
plainLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "plain loopback bind must succeed")
t.Cleanup(func() { _ = plainLn.Close() })
tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "tls loopback bind must succeed")
t.Cleanup(func() { _ = tlsLn.Close() })
router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr()))
tlsConfig := selfSignedTLSConfig(t)
httpsServer := &http.Server{
Handler: tlsHandler,
TLSConfig: tlsConfig,
ReadHeaderTimeout: time.Second,
}
t.Cleanup(func() { _ = httpsServer.Close() })
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }()
go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls"))
tlsConn, err := tls.Dial("tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec
require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)")
t.Cleanup(func() { _ = tlsConn.Close() })
req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil)
require.NoError(t, err)
require.NoError(t, req.Write(tlsConn), "TLS request write must succeed")
resp, err := http.ReadResponse(bufioReader(tlsConn), req)
require.NoError(t, err, "must read TLS response")
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached")
select {
case host := <-hits:
assert.Equal(t, "app.example", host, "TLS handler must observe the request Host")
case <-time.After(2 * time.Second):
t.Fatal("TLS handler was not invoked — peek/dispatch path is broken")
}
}
func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM)
require.NoError(t, err, "load static self-signed cert")
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
}
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
// 127.0.0.1 — only used by tests that need a working TLS handshake.
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)

View File

@@ -0,0 +1,47 @@
package auth
import (
"context"
"net/netip"
)
// PeerIdentity describes the locally-known facts about a peer reachable on
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
// Phase V2 will populate them once RemotePeerConfig carries user identity.
type PeerIdentity struct {
PubKey string
TunnelIP netip.Addr
FQDN string
// V2 fields (zero in V1).
UserID string
Email string
Groups []string
}
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
// available peerstore data. ok=false means the IP is not in the calling
// account's roster.
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)
type tunnelLookupContextKey struct{}
// WithTunnelLookup attaches a per-account peerstore lookup function to
// the request context. The auth middleware calls this lookup before
// hitting management's ValidateTunnelPeer to short-circuit unknown IPs
// and to skip the RPC for already-cached identities.
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
if lookup == nil {
return ctx
}
return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
}
// TunnelLookupFromContext returns the peerstore lookup attached to ctx,
// or nil when the request did not arrive on a per-account listener.
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc)
return v
}

View File

@@ -36,6 +36,7 @@ type authenticator interface {
// SessionValidator validates session tokens and checks user access permissions.
type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
@@ -56,12 +57,21 @@ type DomainConfig struct {
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
// Private routes the domain through ValidateTunnelPeer; failure → 403.
Private bool
}
type validationResult struct {
UserID string
UserEmail string
Valid bool
DeniedReason string
Groups []string
// GroupNames carries the human-readable display names for Groups,
// ordered identically (positional pairing). May be shorter than
// Groups for tokens minted before names were embedded; the consumer
// falls back to ids for missing positions.
GroupNames []string
}
// Middleware applies per-domain authentication and IP restriction checks.
@@ -71,6 +81,7 @@ type Middleware struct {
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
tunnelCache *tunnelValidationCache
}
// NewMiddleware creates a new authentication middleware. The sessionValidator is
@@ -84,6 +95,7 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
tunnelCache: newTunnelValidationCache(),
}
}
@@ -111,6 +123,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
// Private services bypass operator schemes and gate on tunnel peer.
if config.Private {
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
@@ -129,10 +150,54 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
return
}
if mw.blockOIDCOnPlainHTTP(w, r, config) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
// requestIsPlainHTTP reports whether the request arrived without TLS.
// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block.
func requestIsPlainHTTP(r *http.Request) bool {
return r.TLS == nil
}
// hasOIDCScheme reports whether any of the configured schemes requires
// TLS to round-trip safely with an external IdP.
func hasOIDCScheme(schemes []Scheme) bool {
for _, s := range schemes {
if s.Type() == auth.MethodOIDC {
return true
}
}
return false
}
// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit
// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing
// the misconfiguration here yields a clearer error than the IdP's
// "invalid redirect_uri" round-trip.
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if !requestIsPlainHTTP(r) {
return false
}
if !hasOIDCScheme(config.Schemes) {
return false
}
mw.logger.WithFields(log.Fields{
"host": r.Host,
"remote": r.RemoteAddr,
}).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")
http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
return true
}
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
mw.domainsMux.RLock()
defer mw.domainsMux.RUnlock()
@@ -246,18 +311,117 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
if err != nil {
return false
}
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if requestIsPlainHTTP(r) {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": r.RemoteAddr,
}).Warn("session cookie on plain HTTP path; cookie auth requires TLS — use port 443")
}
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
if err != nil {
return false
}
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(userID)
cd.SetUserEmail(email)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
cd.SetAuthMethod(method)
}
next.ServeHTTP(w, r)
return true
}
// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the
// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy
// asks management to resolve it to a peer/user and to gate by the service's
// distribution_groups. On success the proxy installs the freshly minted JWT
// as a session cookie, sets UserID + Method=oidc on the captured data, and
// forwards directly — operators see the same access-log shape as if the user
// had completed an OIDC redirect. Any failure (private-range mismatch,
// management unreachable, peer unknown, user not in group) returns false so
// the caller falls back to the existing OIDC scheme dispatch.
//
// Phase 3 adds a local-first short-circuit: when the request arrived on a
// per-account inbound listener the context carries a peerstore lookup
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
// roster the proxy denies fast without calling management. If the lookup
// confirms a known peer the RPC still runs for the user-identity tail
// (UserID + group access), but its result is cached for tunnelCacheTTL so
// repeat requests skip management entirely.
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
if mw.sessionValidator == nil {
return false
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
return false
}
if !isTunnelSourceIP(clientIP) {
return false
}
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
}
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
accountID: config.AccountID,
tunnelIP: clientIP,
domain: host,
}, mw.validateTunnelPeer)
if err != nil {
mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
return false
}
if !resp.GetValid() || resp.GetSessionToken() == "" {
return false
}
setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(resp.GetUserId())
cd.SetUserEmail(resp.GetUserEmail())
cd.SetUserGroups(resp.GetPeerGroupIds())
cd.SetUserGroupNames(resp.GetPeerGroupNames())
cd.SetAuthMethod(auth.MethodOIDC.String())
}
next.ServeHTTP(w, r)
return true
}
// validateTunnelPeer adapts the SessionValidator interface to the cache's
// validateTunnelPeerFn signature.
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
}
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
// allocates tunnel addresses from by default. IsPrivate() doesn't include
// it, so we check it explicitly.
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
// isTunnelSourceIP reports whether ip falls within an address range typical
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
// (NetBird's default range). Loopback and link-local are excluded — the
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
func isTunnelSourceIP(ip netip.Addr) bool {
if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
return false
}
if ip.IsPrivate() {
return true
}
return cgnatPrefix.Contains(ip)
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
@@ -286,7 +450,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
setHeaderCapturedData(r.Context(), "", "", nil, nil)
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
@@ -298,7 +462,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -306,6 +470,9 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(auth.MethodHeader.String())
}
@@ -315,7 +482,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
setHeaderCapturedData(r.Context(), "", "", nil, nil)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
@@ -327,7 +494,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
@@ -335,6 +502,9 @@ func setHeaderCapturedData(ctx context.Context, userID string) {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
cd.SetUserEmail(userEmail)
cd.SetUserGroups(groups)
cd.SetUserGroupNames(groupNames)
}
// authenticateWithSchemes tries each configured auth scheme in order.
@@ -405,6 +575,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
requestID = cd.GetRequestID()
}
@@ -419,6 +592,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetUserEmail(result.UserEmail)
cd.SetUserGroups(result.Groups)
cd.SetUserGroupNames(result.GroupNames)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
@@ -454,12 +630,9 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
return false
}
// AddDomain registers authentication schemes for the given domain.
// If schemes are provided, a valid session public key is required to sign/verify
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
@@ -467,6 +640,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -488,6 +662,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
Private: private,
}
return nil
}
@@ -518,18 +693,27 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
}).Debug("Session validation denied")
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: false,
DeniedReason: resp.DeniedReason,
Groups: resp.GetPeerGroupIds(),
GroupNames: resp.GetPeerGroupNames(),
}, nil
}
return &validationResult{UserID: resp.UserId, Valid: true}, nil
return &validationResult{
UserID: resp.UserId,
UserEmail: resp.GetUserEmail(),
Valid: true,
Groups: resp.GetPeerGroupIds(),
GroupNames: resp.GetPeerGroupNames(),
}, nil
}
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err
}
return &validationResult{UserID: userID, Valid: true}, nil
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
}
// stripSessionTokenParam returns the request URI with the session_token query

View File

@@ -4,13 +4,16 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"sync"
"testing"
"time"
@@ -62,7 +65,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
require.NoError(t, err)
mw.domainsMux.RLock()
@@ -79,7 +82,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -93,7 +96,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "decode session public key")
@@ -108,7 +111,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid session public key size")
@@ -121,7 +124,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
require.NoError(t, err, "domains with no auth schemes should not require a key")
mw.domainsMux.RLock()
@@ -137,8 +140,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
mw.domainsMux.RLock()
config := mw.domains["example.com"]
@@ -154,7 +157,7 @@ func TestRemoveDomain(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
mw.RemoveDomain("example.com")
@@ -178,7 +181,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -195,7 +198,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -216,7 +219,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -237,9 +240,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
@@ -262,15 +265,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
assert.Equal(t, "authenticated", rec.Body.String())
}
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
// JWT's groups claim into CapturedData so policy-aware middlewares can
// authorise without an extra management round-trip.
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
groups := []string{"engineering", "sre"}
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
require.NoError(t, err)
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cd := proxy.CapturedDataFromContext(r.Context())
require.NotNil(t, cd, "captured data must be present in request context")
assert.Equal(t, "test-user", cd.GetUserID())
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
}
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Sign a token that expired 1 second ago.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
require.NoError(t, err)
var backendCalled bool
@@ -293,10 +329,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Token signed for a different domain audience.
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -320,10 +356,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
kp2 := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
// Token signed with a different private key.
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
@@ -345,7 +381,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
require.NoError(t, err)
scheme := &stubScheme{
@@ -357,7 +393,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -410,7 +446,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -427,7 +463,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
// First scheme (PIN) always fails, second scheme (password) succeeds.
@@ -446,7 +482,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
var backendCalled bool
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -476,7 +512,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
return "invalid-jwt-token", "", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -500,7 +536,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
key := base64.StdEncoding.EncodeToString(randomBytes)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
}
@@ -509,10 +545,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
// Attempt to overwrite with an invalid key.
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
require.Error(t, err)
// The original valid config should still be intact.
@@ -536,7 +572,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -563,7 +599,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
return "", "password", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -590,7 +626,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -678,7 +714,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -714,7 +750,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -755,7 +791,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
require.NoError(t, err)
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -781,11 +817,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
return "", oidcURL, nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -809,11 +846,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
return "", "pin", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
@@ -834,7 +872,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
// returns a signed session token when the expected header value is provided.
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
t.Helper()
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
@@ -852,7 +890,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
capturedData := proxy.NewCapturedData("")
@@ -895,7 +933,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
// Also add a PIN scheme so we can verify fallthrough behavior.
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -915,7 +953,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
return &proto.AuthenticateResponse{Success: false}, nil
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
capturedData := proxy.NewCapturedData("")
handler := mw.Protect(newPassthroughHandler())
@@ -938,7 +976,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
return nil, errors.New("gRPC unavailable")
}}
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(newPassthroughHandler())
@@ -955,7 +993,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
kp := generateTestKeyPair(t)
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
@@ -1006,7 +1044,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
ha := req.GetHeaderAuth()
if ha != nil && accepted[ha.GetHeaderValue()] {
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
require.NoError(t, err)
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
}
@@ -1015,7 +1053,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
// Single Header scheme (as if one entry existed), but the mock checks both values.
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -1059,3 +1097,173 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
assert.False(t, backendCalled, "unknown token should be rejected")
})
}
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
// scheme is configured and the request arrived without TLS, the middleware
// short-circuits with a 400 instead of dispatching to the IdP redirect.
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
}
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
// over TLS — the block only fires on plain HTTP.
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{
method: auth.MethodOIDC,
authFn: func(_ *http.Request) (string, string, error) {
return "", "https://idp.example.com/authorize", nil
},
}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
}
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
// block only fires when an OIDC scheme is configured. PIN-only domains
// pass through normally on plain HTTP.
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
}
// TestProtect_SessionCookieOnPlainHTTP_LogsWarn verifies that a request
// carrying a valid session cookie over plain HTTP is still forwarded but
// emits a WARN-level log line for the operator.
func TestProtect_SessionCookieOnPlainHTTP_LogsWarn(t *testing.T) {
logger, hook := newTestLogger()
mw := NewMiddleware(logger, nil, nil)
kp := generateTestKeyPair(t)
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
var backendCalled bool
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, backendCalled, "backend should still be reached — we don't drop the cookie")
assert.Equal(t, http.StatusOK, rec.Code)
var found bool
for _, entry := range hook.entries() {
if entry.Level == log.WarnLevel && strings.Contains(entry.Message, "session cookie on plain HTTP path") {
found = true
break
}
}
assert.True(t, found, "expected WARN log for session cookie on plain HTTP")
}
// TestProtect_SessionCookieOverTLS_NoWarn confirms the WARN only fires
// on plain HTTP — TLS requests with a session cookie behave as before.
func TestProtect_SessionCookieOverTLS_NoWarn(t *testing.T) {
logger, hook := newTestLogger()
mw := NewMiddleware(logger, nil, nil)
kp := generateTestKeyPair(t)
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
token, err := sessionkey.SignToken(kp.PrivateKey, "user-1", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
require.NoError(t, err)
handler := mw.Protect(newPassthroughHandler())
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
req.TLS = &tls.ConnectionState{}
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
for _, entry := range hook.entries() {
assert.NotContains(t, entry.Message, "session cookie on plain HTTP path", "no plain-HTTP cookie warn expected over TLS")
}
}
// captureHook is a minimal logrus hook that records emitted entries for
// inspection in tests. It avoids pulling in the full sirupsen test
// helpers package (which the rest of the codebase doesn't use).
type captureHook struct {
mu sync.Mutex
records []log.Entry
}
func (h *captureHook) Levels() []log.Level { return log.AllLevels }
func (h *captureHook) Fire(entry *log.Entry) error {
h.mu.Lock()
defer h.mu.Unlock()
h.records = append(h.records, *entry)
return nil
}
func (h *captureHook) entries() []log.Entry {
h.mu.Lock()
defer h.mu.Unlock()
return append([]log.Entry{}, h.records...)
}
// newTestLogger builds an isolated logrus logger with a capture hook so
// tests can assert on emitted records without contending on the global
// logger.
func newTestLogger() (*log.Logger, *captureHook) {
hook := &captureHook{}
logger := log.New()
logger.SetOutput(io.Discard)
logger.AddHook(hook)
logger.SetLevel(log.DebugLevel)
return logger, hook
}

View File

@@ -0,0 +1,171 @@
package auth
import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is
// reused before re-fetching from management. 5 minutes balances freshness
// against management load on busy mesh networks.
const tunnelCacheTTL = 30 * time.Second
// tunnelCachePerAccount caps the number of cached identities per account.
// Bounded eviction avoids memory growth in pathological cases (huge peer
// roster, brief request bursts) while staying generous for normal use.
const tunnelCachePerAccount = 1024
// tunnelCacheKey identifies a cached entry by tunnel IP and originating
// account. Domain is part of the value, not the key, because the
// management response is per (account, IP) — domain only gates whether a
// re-fetch is needed if the operator is accessing a different service.
type tunnelCacheKey struct {
accountID types.AccountID
tunnelIP netip.Addr
domain string
}
// tunnelCacheEntry stores a positive validation response with the time it
// was minted. Entries past tunnelCacheTTL are treated as misses.
type tunnelCacheEntry struct {
resp *proto.ValidateTunnelPeerResponse
cachedAt time.Time
}
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
// (accountID, tunnelIP, domain). Only successful, valid responses are
// cached — denials skip the cache so policy changes apply immediately.
// Single-flight de-duplicates concurrent fetches for the same key so a
// burst of cold requests collapses into a single RPC.
type tunnelValidationCache struct {
mu sync.Mutex
entries map[types.AccountID]*accountBucket
flight singleflight.Group
ttl time.Duration
maxSize int
now func() time.Time
}
// accountBucket holds the cached entries for a single account, with a
// FIFO eviction queue used when the bucket exceeds maxSize.
type accountBucket struct {
items map[tunnelCacheKey]tunnelCacheEntry
order []tunnelCacheKey
}
// newTunnelValidationCache constructs a cache with default TTL and bounds.
func newTunnelValidationCache() *tunnelValidationCache {
return &tunnelValidationCache{
entries: make(map[types.AccountID]*accountBucket),
ttl: tunnelCacheTTL,
maxSize: tunnelCachePerAccount,
now: time.Now,
}
}
// get returns a cached response for the key, or nil when missing or
// expired. Expired entries are evicted lazily on read.
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
c.mu.Lock()
defer c.mu.Unlock()
bucket, ok := c.entries[key.accountID]
if !ok {
return nil
}
entry, ok := bucket.items[key]
if !ok {
return nil
}
if c.now().Sub(entry.cachedAt) > c.ttl {
delete(bucket.items, key)
bucket.order = removeKey(bucket.order, key)
return nil
}
return entry.resp
}
// put records a positive response under the key. Evicts the oldest entry
// in the account's bucket when the bound is exceeded.
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
c.mu.Lock()
defer c.mu.Unlock()
bucket, ok := c.entries[key.accountID]
if !ok {
bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
c.entries[key.accountID] = bucket
}
if _, exists := bucket.items[key]; !exists {
bucket.order = append(bucket.order, key)
}
bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}
for len(bucket.order) > c.maxSize {
oldest := bucket.order[0]
bucket.order = bucket.order[1:]
delete(bucket.items, oldest)
}
}
// removeKey drops the first occurrence of needle from order. The cache
// uses small slices so a linear scan is cheaper than a map+slice combo.
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
for i, k := range order {
if k == needle {
return append(order[:i], order[i+1:]...)
}
}
return order
}
// flightKey turns a cache key into a single-flight string. AccountID and
// IP isolation by themselves are insufficient because different domains
// for the same peer/account may have different group access.
func flightKey(key tunnelCacheKey) string {
return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain
}
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
// the SessionValidator.ValidateTunnelPeer signature without exposing the
// gRPC option variadic, since callers don't need it on the cache hot path.
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
// fetch returns a cached response when present, otherwise calls validate
// under single-flight and caches the result. Denied responses pass
// through but are not cached so policy changes apply immediately.
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
if resp := c.get(key); resp != nil {
return resp, true, nil
}
flight := flightKey(key)
res, err, _ := c.flight.Do(flight, func() (any, error) {
if cached := c.get(key); cached != nil {
return cached, nil
}
resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{
TunnelIp: key.tunnelIP.String(),
Domain: key.domain,
})
if err != nil {
return nil, err
}
if resp.GetValid() && resp.GetSessionToken() != "" {
c.put(key, resp)
}
return resp, nil
})
if err != nil {
return nil, false, err
}
resp, _ := res.(*proto.ValidateTunnelPeerResponse)
return resp, false, nil
}

View File

@@ -0,0 +1,171 @@
package auth
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey {
return tunnelCacheKey{
accountID: account,
tunnelIP: netip.MustParseAddr(ip),
domain: domain,
}
}
func TestTunnelCache_HitSkipsRPC(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil
}
resp, fromCache, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp, "first fetch returns RPC response")
assert.False(t, fromCache, "first fetch must not be cached")
resp2, fromCache2, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
require.NotNil(t, resp2, "second fetch returns cached response")
assert.True(t, fromCache2, "second fetch must be served from cache")
assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity")
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit")
}
func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) {
cache := newTunnelValidationCache()
clock := time.Now()
cache.now = func() time.Time { return clock }
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
}
_, _, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "first fetch issues one RPC")
clock = clock.Add(tunnelCacheTTL + time.Second)
_, fromCache, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err)
assert.False(t, fromCache, "expired entry must miss the cache")
assert.Equal(t, int32(2), atomic.LoadInt32(&calls), "expired entry forces a re-fetch")
}
func TestTunnelCache_DeniedResponseNotCached(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil
}
for i := 0; i < 3; i++ {
_, _, err := cache.fetch(context.Background(), key, validate)
require.NoError(t, err, "fetch must not error on denied response")
}
assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately")
}
func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) {
cache := newTunnelValidationCache()
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
gate := make(chan struct{})
var calls int32
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&calls, 1)
<-gate
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
}
const workers = 16
var wg sync.WaitGroup
wg.Add(workers)
results := make([]bool, workers)
for i := 0; i < workers; i++ {
go func(idx int) {
defer wg.Done()
resp, _, err := cache.fetch(context.Background(), key, validate)
results[idx] = err == nil && resp.GetValid()
}(i)
}
time.Sleep(20 * time.Millisecond)
close(gate)
wg.Wait()
for i, ok := range results {
assert.Truef(t, ok, "worker %d should observe a successful response", i)
}
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "single-flight must collapse concurrent cold fetches into one RPC")
}
func TestTunnelCache_PerAccountIsolation(t *testing.T) {
cache := newTunnelValidationCache()
keyA := newTestKey("acct-a", "100.64.0.10", "svc.example")
keyB := newTestKey("acct-b", "100.64.0.10", "svc.example")
var callsA, callsB int32
validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsA, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil
}
validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
atomic.AddInt32(&callsB, 1)
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil
}
respA, _, err := cache.fetch(context.Background(), keyA, validateA)
require.NoError(t, err)
respB, _, err := cache.fetch(context.Background(), keyB, validateB)
require.NoError(t, err)
assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a")
assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once")
assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once")
}
func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) {
cache := newTunnelValidationCache()
cache.maxSize = 2
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil
}
keys := []tunnelCacheKey{
newTestKey("acct-1", "100.64.0.10", "svc"),
newTestKey("acct-1", "100.64.0.11", "svc"),
newTestKey("acct-1", "100.64.0.12", "svc"),
}
for _, k := range keys {
_, _, err := cache.fetch(context.Background(), k, validate)
require.NoError(t, err)
}
assert.Nil(t, cache.get(keys[0]), "oldest key should be evicted past maxSize")
assert.NotNil(t, cache.get(keys[1]), "second-newest must remain cached")
assert.NotNil(t, cache.get(keys[2]), "newest must remain cached")
}

View File

@@ -0,0 +1,325 @@
package auth
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/netip"
"sync/atomic"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/shared/management/proto"
)
// stubSessionValidator records ValidateTunnelPeer calls and returns the
// pre-canned response. Counts let tests assert RPC traffic.
type stubSessionValidator struct {
respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse
respErr error
tunnelCalls atomic.Int32
}
func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
return &proto.ValidateSessionResponse{Valid: false}, nil
}
func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
s.tunnelCalls.Add(1)
if s.respErr != nil {
return nil, s.respErr
}
if s.respFn != nil {
return s.respFn(in), nil
}
return &proto.ValidateTunnelPeerResponse{Valid: false}, nil
}
func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware {
t.Helper()
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false))
return mw
}
func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
r.Host = "svc.example"
r.RemoteAddr = remoteAddr
return w, r
}
// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast verifies the
// short-circuit: a tunnel IP not in the account's roster never reaches
// management's ValidateTunnelPeer.
func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) {
validator := &stubSessionValidator{}
mw := newTunnelMiddleware(t, validator)
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, false
})
w, r := newTunnelRequest("100.64.0.99:55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.False(t, handled, "unknown peer must fall through, not forward")
assert.False(t, called, "next handler must not run for unknown peer")
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss")
}
// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData verifies the proxy
// surfaces the calling peer's group memberships from ValidateTunnelPeerResponse
// onto CapturedData so policy-aware middlewares can authorise without an
// extra management round-trip.
func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
groups := []string{"engineering", "sre"}
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tok",
UserId: "user-1",
PeerGroupIds: groups,
}
},
}
mw := newTunnelMiddleware(t, validator)
w, r := newTunnelRequest("100.64.0.10:55555")
cd := proxy.NewCapturedData("")
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
require.True(t, handled, "valid tunnel-peer response must forward")
require.True(t, called, "next handler must run")
assert.Equal(t, "user-1", cd.GetUserID(), "user id must propagate from tunnel-peer response")
assert.Equal(t, groups, cd.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response")
}
// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs verifies that a
// known tunnel IP still triggers ValidateTunnelPeer for the user-identity
// tail (UserID + group access). Phase 3 only short-circuits the deny path.
func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
if ip == knownIP {
return PeerIdentity{PubKey: "pk", TunnelIP: ip, FQDN: "peer.netbird.cloud"}, true
}
return PeerIdentity{}, false
})
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "known peer with valid RPC response must forward")
assert.True(t, called, "next handler must run on success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
}
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
// behaviour stays intact on the host-level listener (no lookup attached).
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
w, r := newTunnelRequest("100.64.0.10:55555")
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "host-level path forwards on positive RPC result")
assert.True(t, called, "next handler runs on host-level success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
}
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
// failure still falls through to the next scheme (no false positive).
func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) {
validator := &stubSessionValidator{respErr: errors.New("management down")}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{TunnelIP: ip}, true
})
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
assert.False(t, handled, "RPC error must let the caller try other schemes")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC was attempted exactly once")
}
// TestForwardWithTunnelPeer_CacheReusesPositiveResponse confirms the
// (account, IP, domain) cache prevents repeated RPCs for the same peer.
func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
},
}
mw := newTunnelMiddleware(t, validator)
for i := 0; i < 4; i++ {
w, r := newTunnelRequest("100.64.0.10:55555")
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
require.True(t, handled, "iteration %d should forward", i)
}
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management")
}
// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey ensures cache keys
// honour account scoping — same tunnel IP on different accounts must not
// collide.
func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"}
},
}
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
for _, host := range []string{"svc-a.example", "svc-b.example"} {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
r.Host = host
r.RemoteAddr = "100.64.0.10:55555"
config, _ := mw.getDomainConfig(host)
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
require.True(t, handled, "host %s should forward", host)
}
assert.Equal(t, int32(2), validator.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match")
}
// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache
// guarantees that the deny-fast path leaves the cache untouched, so a
// subsequent request from the same IP after the peerstore catches up
// goes through the normal RPC flow.
func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}
},
}
mw := newTunnelMiddleware(t, validator)
knownIP := netip.MustParseAddr("100.64.0.10")
known := false
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
if known && ip == knownIP {
return PeerIdentity{TunnelIP: ip}, true
}
return PeerIdentity{}, false
})
doRequest := func() bool {
w, r := newTunnelRequest(knownIP.String() + ":55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig("svc.example")
return mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
}
require.False(t, doRequest(), "first request must short-circuit")
require.Equal(t, int32(0), validator.tunnelCalls.Load(), "short-circuit must not populate the cache")
known = true
require.True(t, doRequest(), "second request with peer in roster must forward via RPC")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC runs once after peerstore catches up")
}
func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) {
mw := NewMiddleware(log.New(), nil, nil)
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
called := false
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.False(t, called)
}
func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tok",
UserId: "user-1",
}
},
}
mw := NewMiddleware(log.New(), validator, nil)
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
called := false
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, called)
}

View File

@@ -11,7 +11,6 @@ import (
"net/url"
"strings"
"time"
)
// StatusFilters contains filter options for status queries.
@@ -160,6 +159,49 @@ func (c *Client) printClients(data map[string]any) {
for _, item := range clients {
c.printClientRow(item)
}
c.printInboundListeners(clients)
}
func (c *Client) printInboundListeners(clients []any) {
type row struct {
accountID string
tunnelIP string
httpsPort int
httpPort int
}
var rows []row
for _, item := range clients {
client, ok := item.(map[string]any)
if !ok {
continue
}
inbound, ok := client["inbound_listener"].(map[string]any)
if !ok {
continue
}
tunnelIP, _ := inbound["tunnel_ip"].(string)
httpsPort, _ := inbound["https_port"].(float64)
httpPort, _ := inbound["http_port"].(float64)
accountID, _ := client["account_id"].(string)
rows = append(rows, row{
accountID: accountID,
tunnelIP: tunnelIP,
httpsPort: int(httpsPort),
httpPort: int(httpPort),
})
}
if len(rows) == 0 {
return
}
_, _ = fmt.Fprintln(c.out)
_, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):")
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP")
_, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78))
for _, r := range rows {
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", r.accountID, r.tunnelIP, r.httpsPort, r.httpPort)
}
}
func (c *Client) printClientRow(item any) {
@@ -219,7 +261,14 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta
}
func (c *Client) printClientStatus(data map[string]any) {
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
_, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"])
if inbound, ok := data["inbound_listener"].(map[string]any); ok {
tunnelIP, _ := inbound["tunnel_ip"].(string)
httpsPort, _ := inbound["https_port"].(float64)
httpPort, _ := inbound["http_port"].(float64)
_, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort))
}
_, _ = fmt.Fprintln(c.out)
if status, ok := data["status"].(string); ok {
_, _ = fmt.Fprint(c.out, status)
}

View File

@@ -61,6 +61,23 @@ type clientProvider interface {
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// InboundListenerInfo describes a per-account inbound listener as
// surfaced through the debug HTTP handler. Mirrors the proto sub-message
// emitted with SendStatusUpdate so dashboards and CLI tooling see the
// same shape.
type InboundListenerInfo struct {
TunnelIP string `json:"tunnel_ip"`
HTTPSPort uint16 `json:"https_port"`
HTTPPort uint16 `json:"http_port"`
}
// InboundProvider exposes per-account inbound listener state. Optional;
// when nil the debug endpoint omits the inbound section entirely so the
// existing JSON shape stays additive.
type InboundProvider interface {
InboundListeners() map[types.AccountID]InboundListenerInfo
}
// healthChecker provides health probe state.
type healthChecker interface {
ReadinessProbe() bool
@@ -80,6 +97,7 @@ type Handler struct {
provider clientProvider
health healthChecker
certStatus certStatus
inbound InboundProvider
logger *log.Logger
startTime time.Time
templates *template.Template
@@ -108,6 +126,13 @@ func (h *Handler) SetCertStatus(cs certStatus) {
h.certStatus = cs
}
// SetInboundProvider wires per-account inbound listener observability.
// Pass nil (or skip the call) to keep the inbound section out of debug
// responses on proxies that don't run --private-inbound.
func (h *Handler) SetInboundProvider(p InboundProvider) {
h.inbound = p
}
func (h *Handler) loadTemplates() error {
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
if err != nil {
@@ -323,23 +348,35 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients)
if wantJSON {
var inboundAll map[types.AccountID]InboundListenerInfo
if h.inbound != nil {
inboundAll = h.inbound.InboundListeners()
}
clientsJSON := make([]map[string]interface{}, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
row := map[string]interface{}{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
"has_client": info.HasClient,
"created_at": info.CreatedAt,
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
if inb, ok := inboundAll[id]; ok {
row["inbound_listener"] = inb
}
clientsJSON = append(clientsJSON, row)
}
h.writeJSON(w, map[string]interface{}{
resp := map[string]interface{}{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
})
}
if len(inboundAll) > 0 {
resp["inbound_listener_count"] = len(inboundAll)
}
h.writeJSON(w, resp)
return
}
@@ -421,10 +458,14 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
h.writeJSON(w, map[string]interface{}{
resp := map[string]interface{}{
"account_id": accountID,
"status": overview.FullDetailSummary(),
})
}
if info, ok := h.inboundInfoFor(accountID); ok {
resp["inbound_listener"] = info
}
h.writeJSON(w, resp)
return
}
@@ -437,6 +478,18 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
h.renderTemplate(w, "clientDetail", data)
}
// inboundInfoFor returns the inbound listener info for an account, or
// ok=false when no inbound provider is wired or the account has no live
// listener.
func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) {
if h.inbound == nil {
return InboundListenerInfo{}, false
}
all := h.inbound.InboundListeners()
info, ok := all[accountID]
return info, ok
}
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
client, ok := h.provider.GetClient(accountID)
if !ok {

View File

@@ -52,8 +52,15 @@ type CapturedData struct {
origin ResponseOrigin
clientIP netip.Addr
userID string
authMethod string
metadata map[string]string
userEmail string
userGroups []string
// userGroupNames pairs positionally with userGroups; populated from
// the JWT's group_names claim or from ValidateSession/Tunnel
// responses. Slice may be shorter than userGroups for tokens minted
// before names were resolvable.
userGroupNames []string
authMethod string
metadata map[string]string
}
// NewCapturedData creates a CapturedData with the given request ID.
@@ -138,6 +145,81 @@ func (c *CapturedData) GetUserID() string {
return c.userID
}
// SetUserEmail records the authenticated user's email address. Used by
// policy-aware middlewares to stamp identity onto upstream requests
// (e.g. x-litellm-end-user-id) without a management round-trip.
func (c *CapturedData) SetUserEmail(email string) {
c.mu.Lock()
defer c.mu.Unlock()
c.userEmail = email
}
// GetUserEmail returns the authenticated user's email address. Returns
// the empty string when the auth path didn't carry an email (e.g.
// non-OIDC schemes or legacy JWTs minted before the email claim).
func (c *CapturedData) GetUserEmail() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.userEmail
}
// SetUserGroups records the authenticated user's group memberships so
// downstream policy-aware middlewares can authorise the request without
// an additional management round-trip. The auth middleware populates this
// from ValidateSessionResponse / ValidateTunnelPeerResponse and from the
// session JWT's groups claim on cookie-bearing requests.
func (c *CapturedData) SetUserGroups(groups []string) {
c.mu.Lock()
defer c.mu.Unlock()
if len(groups) == 0 {
c.userGroups = nil
return
}
c.userGroups = append(c.userGroups[:0], groups...)
}
// GetUserGroups returns a copy of the authenticated user's group
// memberships.
func (c *CapturedData) GetUserGroups() []string {
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.userGroups) == 0 {
return nil
}
out := make([]string, len(c.userGroups))
copy(out, c.userGroups)
return out
}
// SetUserGroupNames records the human-readable display names for the
// user's groups, ordered identically to UserGroups (positional
// pairing). Stamped onto upstream requests as X-NetBird-Groups so
// downstream services can read names rather than opaque ids.
func (c *CapturedData) SetUserGroupNames(names []string) {
c.mu.Lock()
defer c.mu.Unlock()
if len(names) == 0 {
c.userGroupNames = nil
return
}
c.userGroupNames = append(c.userGroupNames[:0], names...)
}
// GetUserGroupNames returns a copy of the authenticated user's group
// display names. Position i pairs with UserGroups[i]. May be shorter
// than UserGroups for tokens minted before names were resolvable; the
// consumer should fall back to ids for missing positions.
func (c *CapturedData) GetUserGroupNames() []string {
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.userGroupNames) == 0 {
return nil
}
out := make([]string, len(c.userGroupNames))
copy(out, c.userGroupNames)
return out
}
// SetAuthMethod sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()

View File

@@ -86,6 +86,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if pt.RequestTimeout > 0 {
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
}
if pt.DirectUpstream {
ctx = roundtrip.WithDirectUpstream(ctx)
}
rewriteMatchedPath := result.matchedPath
if pt.PathRewrite == PathRewritePreserve {
@@ -142,6 +145,8 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
r.Out.Header.Set(k, v)
}
stampNetBirdIdentity(r)
clientIP := extractHostIP(r.In.RemoteAddr)
if isTrustedAddr(clientIP, p.trustedProxies) {
@@ -426,3 +431,46 @@ func opErrorContains(err error, substr string) bool {
}
return false
}
const (
// headerNetBirdUser carries the authenticated user's display identity
// (email when the peer is attached to a user, else peer name) onto
// upstream requests. Stripped from inbound requests before stamping
// so a client can't spoof identity by setting the header themselves.
headerNetBirdUser = "X-NetBird-User"
// headerNetBirdGroups carries the user's group display names as a
// comma-separated list. Falls back to group IDs at positions where a
// name wasn't available at session-mint time.
headerNetBirdGroups = "X-NetBird-Groups"
)
// stampNetBirdIdentity injects authenticated identity onto outbound
// requests as X-NetBird-User and X-NetBird-Groups. Always strips any
// client-sent values first (anti-spoof). Skips when the request didn't
// carry CapturedData (early-path errors, internal endpoints).
func stampNetBirdIdentity(r *httputil.ProxyRequest) {
r.Out.Header.Del(headerNetBirdUser)
r.Out.Header.Del(headerNetBirdGroups)
cd := CapturedDataFromContext(r.In.Context())
if cd == nil {
return
}
if email := cd.GetUserEmail(); email != "" {
r.Out.Header.Set(headerNetBirdUser, email)
}
groupIDs := cd.GetUserGroups()
if len(groupIDs) == 0 {
return
}
groupNames := cd.GetUserGroupNames()
labels := make([]string, len(groupIDs))
for i, id := range groupIDs {
if i < len(groupNames) && groupNames[i] != "" {
labels[i] = groupNames[i]
continue
}
labels[i] = id
}
r.Out.Header.Set(headerNetBirdGroups, strings.Join(labels, ","))
}

View File

@@ -1067,3 +1067,67 @@ func TestClassifyProxyError(t *testing.T) {
})
}
}
func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-NetBird-User", "spoofed@evil.io")
pr.In.Header.Set("X-NetBird-Groups", "admin")
pr.Out.Header = pr.In.Header.Clone()
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("X-NetBird-User"),
"client-supplied X-NetBird-User must be stripped when no captured identity is present")
assert.Empty(t, pr.Out.Header.Get("X-NetBird-Groups"),
"client-supplied X-NetBird-Groups must be stripped when no captured identity is present")
}
func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-NetBird-User", "spoofed@evil.io")
pr.Out.Header = pr.In.Header.Clone()
cd := NewCapturedData("req-1")
cd.SetUserEmail("alice@netbird.io")
cd.SetUserGroups([]string{"grp-eng", "grp-ops"})
cd.SetUserGroupNames([]string{"engineering", "operations"})
withCD := pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
pr.In = withCD
rewrite(pr)
assert.Equal(t, "alice@netbird.io", pr.Out.Header.Get("X-NetBird-User"),
"captured email must overwrite any spoofed value")
assert.Equal(t, "engineering,operations", pr.Out.Header.Get("X-NetBird-Groups"),
"group display names must be CSV-joined in positional order")
}
func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
cd := NewCapturedData("req-1")
cd.SetUserEmail("bob@netbird.io")
cd.SetUserGroups([]string{"grp-a", "grp-b"})
cd.SetUserGroupNames([]string{"alpha"})
withCD := pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
pr.In = withCD
rewrite(pr)
assert.Equal(t, "alpha,grp-b", pr.Out.Header.Get("X-NetBird-Groups"),
"positions without a name must fall back to the group id")
}

View File

@@ -28,6 +28,10 @@ type PathTarget struct {
RequestTimeout time.Duration
PathRewrite PathRewriteMode
CustomHeaders map[string]string
// DirectUpstream selects the stdlib HTTP transport (host network stack)
// over the embedded NetBird WireGuard client when forwarding requests
// to this target. Default false → embedded client (existing behaviour).
DirectUpstream bool
}
// Mapping describes how a domain is routed by the HTTP reverse proxy.

View File

@@ -0,0 +1,163 @@
package roundtrip
import (
"bytes"
"crypto/tls"
"io"
"net"
"net/http"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// upstreamLogBodyMax caps the request body bytes copied into the
// debug log line so a giant prompt or streamed payload doesn't fill the
// log. The body itself is always restored to the request unchanged.
const upstreamLogBodyMax = 4096
// MultiTransport dispatches each request to either the embedded NetBird
// http.RoundTripper or a stdlib http.Transport based on a per-request
// context flag set by the reverse-proxy rewrite step. When the flag is
// absent (the default for every existing target), requests follow the
// embedded NetBird path — current behaviour, preserved.
//
// The stdlib branch is used when a target was configured with
// `direct_upstream=true`. It dials via the host's network stack, which
// is what private (`netbird proxy`) deployments and centralised proxies
// fronting host-reachable upstreams (public APIs, LAN services,
// localhost sidecars) want.
type MultiTransport struct {
embedded http.RoundTripper
direct *http.Transport
insecure *http.Transport
logger *log.Logger
}
// NewMultiTransport wires both branches. embedded is the existing NetBird
// roundtripper; the direct branches are constructed here with sensible
// defaults that mirror Go's stdlib defaults plus a dial-timeout wrapper
// honouring the per-request value attached via types.WithDialTimeout.
// Pass embedded=nil to disable the WG branch entirely (every request
// will route direct, regardless of the context flag). logger may be
// nil; when nil the transport falls back to the logrus default
// instance.
func NewMultiTransport(embedded http.RoundTripper, logger *log.Logger) *MultiTransport {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
direct := &http.Transport{
DialContext: dialWithTimeout(dialer.DialContext),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
insecure := direct.Clone()
insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in
if logger == nil {
logger = log.StandardLogger()
}
return &MultiTransport{
embedded: embedded,
direct: direct,
insecure: insecure,
logger: logger,
}
}
// RoundTrip dispatches by reading the direct-upstream flag from the request
// context. When set, the request is forwarded via the stdlib transport,
// honouring the existing per-request skip-TLS-verify flag. Otherwise it
// goes through the embedded NetBird roundtripper.
func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
m.logUpstreamRequest(req)
if DirectUpstreamFromContext(req.Context()) || m.embedded == nil {
if skipTLSVerifyFromContext(req.Context()) {
return m.insecure.RoundTrip(req)
}
return m.direct.RoundTrip(req)
}
return m.embedded.RoundTrip(req)
}
// logUpstreamRequest emits the outbound request method, URL, headers,
// and a (capped) body snippet at info level for debugging. The body is
// read, copied into a snippet, and restored on the request so the
// actual upstream call sees it unchanged.
func (m *MultiTransport) logUpstreamRequest(req *http.Request) {
if req == nil {
return
}
body := snapshotRequestBody(req)
m.logger.Debugf("upstream request: method=%s url=%s host=%s body_length=%d headers=%s body=%s",
req.Method, req.URL.String(), req.Host, req.ContentLength, formatHeaders(req.Header), body)
}
// formatHeaders renders the headers as a deterministic single-line
// string. Multi-valued headers are joined with commas. Sensitive
// header values (the upstream Authorization NetBird just stamped, plus
// any cookie jar that survived) are redacted so logs don't leak the
// provider API key.
func formatHeaders(h http.Header) string {
if len(h) == 0 {
return "{}"
}
keys := make([]string, 0, len(h))
for k := range h {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
sb.WriteByte('{')
for i, k := range keys {
if i > 0 {
sb.WriteByte(' ')
}
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(redactHeaderValue(k, h.Values(k)))
}
sb.WriteByte('}')
return sb.String()
}
// redactHeaderValue replaces sensitive credentials with a placeholder.
// All other header values are joined with commas verbatim.
func redactHeaderValue(name string, values []string) string {
switch strings.ToLower(name) {
case "authorization", "proxy-authorization", "x-api-key", "api-key", "cookie":
return "[redacted]"
}
return strings.Join(values, ",")
}
// snapshotRequestBody returns a printable snippet of the request body
// (capped to upstreamLogBodyMax) and restores the body so downstream
// transports can still read it. Returns the empty string when there's
// no body or it can't be read.
func snapshotRequestBody(req *http.Request) string {
if req.Body == nil || req.Body == http.NoBody {
return ""
}
raw, err := io.ReadAll(req.Body)
if err != nil {
return ""
}
req.Body = io.NopCloser(bytes.NewReader(raw))
// Restore GetBody so transports performing redirects or retries
// still get a fresh reader.
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(raw)), nil
}
if len(raw) > upstreamLogBodyMax {
return string(raw[:upstreamLogBodyMax]) + "...[truncated]"
}
return string(raw)
}

View File

@@ -0,0 +1,79 @@
package roundtrip
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubRoundTripper records whether RoundTrip was called and returns a
// canned response so tests can assert the dispatch decision without
// running a real network.
type stubRoundTripper struct {
called bool
body string
}
func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
s.called = true
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(s.body)),
Header: http.Header{},
}, nil
}
func TestMultiTransport_DispatchesByContextFlag(t *testing.T) {
embedded := &stubRoundTripper{body: "embedded"}
mt := NewMultiTransport(embedded, nil)
t.Run("default routes to embedded", func(t *testing.T) {
embedded.called = false
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "embedded path must not error on stubbed transport")
require.NotNil(t, resp)
_ = resp.Body.Close()
assert.True(t, embedded.called, "request without WithDirectUpstream must hit the embedded transport")
})
t.Run("WithDirectUpstream skips embedded", func(t *testing.T) {
embedded.called = false
// Hit a server we control to verify the stdlib transport is used.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
defer srv.Close()
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "direct path must dial via stdlib transport")
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, "direct", string(body), "stdlib transport must reach the test server")
assert.False(t, embedded.called, "WithDirectUpstream must bypass the embedded transport")
})
}
func TestMultiTransport_NilEmbeddedAlwaysDirects(t *testing.T) {
mt := NewMultiTransport(nil, nil)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "ok")
}))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
require.NoError(t, err)
resp, err := mt.RoundTrip(req)
require.NoError(t, err, "nil embedded must fall through to direct without panic")
_ = resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"sync"
"time"
@@ -76,11 +77,11 @@ type clientEntry struct {
services map[ServiceKey]serviceInfo
createdAt time.Time
started bool
// ready is closed once the client has been fully initialized.
// Callers that find a pending entry wait on this channel before
// accessing the client. A nil initErr means success.
ready chan struct{}
initErr error
// inbound is opaque per-account state owned by the NetBird parent's
// ReadyHandler. The roundtrip package never inspects this value; it
// only stores it so RemovePeer / StopAll can hand it back to the
// matching StopHandler. Nil when no inbound integration is active.
inbound any
// Per-backend in-flight limiting keyed by target host:port.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
@@ -88,6 +89,19 @@ type clientEntry struct {
maxInflight int
}
// IdentityForIP resolves a tunnel IP to the peer identity locally known by
// this account's embedded client. Returns (pubKey, fqdn) on success.
// ok=false means the IP is not in the account's roster — callers can use
// that as a fast deny without round-tripping management. The returned
// strings carry only what the embedded peerstore exposes; user identity
// (UserID / Email / Groups) still flows through ValidateTunnelPeer.
func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if e == nil || e.client == nil || !ip.IsValid() {
return "", "", false
}
return e.client.IdentityForIP(ip)
}
// acquireInflight attempts to acquire an in-flight slot for the given backend.
// It returns a release function that must always be called, and true on success.
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
@@ -117,6 +131,12 @@ type ClientConfig struct {
MgmtAddr string
WGPort uint16
PreSharedKey string
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
// standalone proxy where the embedded client never accepts inbound;
// set to false on the private/embedded proxy so the engine creates
// the ACL manager and applies management's per-policy firewall rules
// (which is what gates per-account inbound listeners on the netstack).
BlockInbound bool
}
type statusNotifier interface {
@@ -142,11 +162,14 @@ type NetBird struct {
clients map[types.AccountID]*clientEntry
initLogOnce sync.Once
statusNotifier statusNotifier
// OnAddPeer, when set, is called after AddPeer completes for a new account
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
// readyHandler runs after the embedded client for an account reports
// Ready. The opaque return value is stored on clientEntry and handed
// back to stopHandler when the entry is torn down. Nil disables the
// hook entirely (default for the standalone proxy).
readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any
// stopHandler runs when an account's last service is removed (or the
// transport is shutting down). Receives whatever readyHandler returned.
stopHandler func(accountID types.AccountID, state any)
}
// ClientDebugInfo contains debug information about a client.
@@ -167,9 +190,6 @@ type skipTLSVerifyContextKey struct{}
// AddPeer registers a service for an account. If the account doesn't have a client yet,
// one is created by authenticating with the management server using the provided token.
// Multiple services can share the same client.
//
// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux
// so that concurrent AddPeer calls for different accounts execute in parallel.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
@@ -177,23 +197,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
entry, exists := n.clients[accountID]
if exists {
ready := entry.ready
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
// If the entry is still being initialized by another goroutine, wait.
if ready != nil {
select {
case <-ready:
case <-ctx.Done():
return ctx.Err()
}
if entry.initErr != nil {
return fmt.Errorf("peer initialization failed: %w", entry.initErr)
}
}
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
@@ -210,43 +217,15 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
return nil
}
// Insert a placeholder so other goroutines calling AddPeer for the same
// account will wait on the ready channel instead of starting a second
// client creation.
entry = &clientEntry{
services: map[ServiceKey]serviceInfo{key: si},
ready: make(chan struct{}),
}
n.clients[accountID] = entry
n.clientsMux.Unlock()
createStart := time.Now()
created, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if n.OnAddPeer != nil {
n.OnAddPeer(time.Since(createStart), err)
}
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
if err != nil {
entry.initErr = err
close(entry.ready)
n.clientsMux.Lock()
delete(n.clients, accountID)
n.clientsMux.Unlock()
return err
}
// Transfer any services that were registered by concurrent AddPeer calls
// while we were creating the client.
n.clientsMux.Lock()
for k, v := range entry.services {
created.services[k] = v
}
created.ready = nil
n.clients[accountID] = created
n.clients[accountID] = entry
n.clientsMux.Unlock()
close(entry.ready)
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
@@ -254,13 +233,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, created.client)
go n.runClientStartup(ctx, accountID, entry.client)
return nil
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client.
// and creates an embedded NetBird client. Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -318,7 +297,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
BlockInbound: n.clientCfg.BlockInbound,
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
@@ -385,8 +364,25 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
}
}
readyHandler := n.readyHandler
n.clientsMux.Unlock()
if readyHandler != nil {
state := readyHandler(ctx, accountID, client)
n.clientsMux.Lock()
if e, ok := n.clients[accountID]; ok {
e.inbound = state
} else if state != nil && n.stopHandler != nil {
// Account was removed while readyHandler ran; tear down the
// resources it just brought up.
stop := n.stopHandler
n.clientsMux.Unlock()
stop(accountID, state)
n.clientsMux.Lock()
}
n.clientsMux.Unlock()
}
if n.statusNotifier == nil {
return
}
@@ -432,11 +428,15 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -450,6 +450,9 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
@@ -536,8 +539,12 @@ func (n *NetBird) StopAll(ctx context.Context) error {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
stopHandler := n.stopHandler
var merr *multierror.Error
for accountID, entry := range n.clients {
if entry.inbound != nil && stopHandler != nil {
stopHandler(accountID, entry.inbound)
}
entry.transport.CloseIdleConnections()
entry.insecureTransport.CloseIdleConnections()
if err := entry.client.Stop(ctx); err != nil {
@@ -590,6 +597,19 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
return entry.client, true
}
// IdentityForIP resolves a tunnel IP to a peer identity local to the given
// account. Delegates to clientEntry.IdentityForIP. Returns ok=false when
// the account has no client or the IP is not in its peerstore.
func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) {
n.clientsMux.RLock()
entry, exists := n.clients[accountID]
n.clientsMux.RUnlock()
if !exists {
return "", "", false
}
return entry.IdentityForIP(ip)
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
@@ -645,6 +665,18 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
}
}
// SetClientLifecycle registers callbacks that run when an embedded
// client becomes ready and when its entry is torn down. The opaque value
// returned by ready is stored on the entry and handed back to stop on
// cleanup. Must be called before AddPeer. A nil pair leaves the
// outbound-only behaviour intact.
func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) {
n.clientsMux.Lock()
defer n.clientsMux.Unlock()
n.readyHandler = ready
n.stopHandler = stop
}
// dialWithTimeout wraps a DialContext function so that any dial timeout
// stored in the context (via types.WithDialTimeout) is applied only to
// the connection establishment phase, not the full request lifetime.
@@ -687,3 +719,22 @@ func skipTLSVerifyFromContext(ctx context.Context) bool {
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
return v
}
// directUpstreamContextKey signals that the request should bypass the embedded
// NetBird WireGuard client and dial via the host's network stack instead.
// Set by the reverse-proxy rewrite step when the matched target carries
// PathTarget.DirectUpstream; consumed by MultiTransport.
type directUpstreamContextKey struct{}
// WithDirectUpstream marks the context so MultiTransport routes the request
// through its stdlib transport instead of the embedded NetBird roundtripper.
func WithDirectUpstream(ctx context.Context) context.Context {
return context.WithValue(ctx, directUpstreamContextKey{}, true)
}
// DirectUpstreamFromContext reports whether the context has been marked to
// bypass the embedded NetBird client.
func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}

View File

@@ -3,6 +3,7 @@ package roundtrip
import (
"context"
"net/http"
"net/netip"
"sync"
"testing"
@@ -305,6 +306,36 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
assert.True(t, calls[0].connected)
}
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
// public lookup short-circuits when no client has been registered for
// the queried account. The auth middleware uses ok=false as a fast deny.
func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) {
nb := mockNetBird()
_, _, ok := nb.IdentityForIP("acct-missing", netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "unknown account must yield ok=false")
}
// TestClientEntry_IdentityForIP_NilClientGuard ensures the receiver
// methods stay safe when called on partially-initialized state, which
// can happen briefly during AddPeer setup or test fixtures.
func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) {
var e *clientEntry
_, _, ok := e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "nil clientEntry must yield ok=false")
e = &clientEntry{}
_, _, ok = e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
assert.False(t, ok, "clientEntry with nil embed.Client must yield ok=false")
}
// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse covers the input
// guard so callers don't have to repeat the check.
func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
e := &clientEntry{}
_, _, ok := e.IdentityForIP(netip.Addr{})
assert.False(t, ok, "invalid IP must yield ok=false")
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{

View File

@@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(hello)
conn := &readerConn{Reader: r}
sni, wrapped, err := PeekClientHello(conn)
sni, wrapped, _, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}
@@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
for b.Loop() {
r := bytes.NewReader(httpReq)
conn := &readerConn{Reader: r}
_, wrapped, err := PeekClientHello(conn)
_, wrapped, _, err := PeekClientHello(conn)
if err != nil {
b.Fatal(err)
}

View File

@@ -100,28 +100,50 @@ type Router struct {
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
httpCh chan net.Conn
httpListener *chanListener
mu sync.RWMutex
routes map[SNIHost][]Route
fallback *Route
draining bool
dialResolve DialResolver
activeConns sync.WaitGroup
activeRelays sync.WaitGroup
relaySem chan struct{}
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server.
// Set only when NewRouter is called with WithPlainHTTP option (used by
// per-account inbound listeners that accept both :80 and :443 traffic).
// Nil for the host SNI router and for port routers.
httpPlainCh chan net.Conn
httpPlainListener *chanListener
mu sync.RWMutex
routes map[SNIHost][]Route
fallback *Route
draining bool
dialResolve DialResolver
activeConns sync.WaitGroup
activeRelays sync.WaitGroup
relaySem chan struct{}
drainDone chan struct{}
observer RelayObserver
accessLog l4Logger
geo restrict.GeoResolver
// svcCtxs tracks a context per service ID. All relay goroutines for a
// service derive from its context; canceling it kills them immediately.
svcCtxs map[types.ServiceID]context.Context
svcCancels map[types.ServiceID]context.CancelFunc
}
// RouterOption customises Router construction.
type RouterOption func(*Router)
// WithPlainHTTP enables a parallel plain-HTTP channel on the router. When
// set, connections whose first byte is not a TLS handshake are forwarded
// to the plain channel returned by HTTPListenerPlain instead of the TLS
// channel. Used by per-account inbound listeners that share both :80 and
// :443 traffic on the same router.
func WithPlainHTTP(addr net.Addr) RouterOption {
return func(r *Router) {
ch := make(chan net.Conn, httpChannelBuffer)
r.httpPlainCh = ch
r.httpPlainListener = newChanListener(ch, addr)
}
}
// NewRouter creates a new SNI-based connection router.
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router {
httpCh := make(chan net.Conn, httpChannelBuffer)
return &Router{
r := &Router{
logger: logger,
httpCh: httpCh,
httpListener: newChanListener(httpCh, addr),
@@ -131,6 +153,10 @@ func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Rou
svcCtxs: make(map[types.ServiceID]context.Context),
svcCancels: make(map[types.ServiceID]context.CancelFunc),
}
for _, opt := range opts {
opt(r)
}
return r
}
// NewPortRouter creates a Router for a dedicated port without an HTTP
@@ -153,6 +179,16 @@ func (r *Router) HTTPListener() net.Listener {
return r.httpListener
}
// HTTPListenerPlain returns a net.Listener yielding non-TLS connections
// for use with a parallel plain http.Server. Returns nil when the router
// was not constructed with WithPlainHTTP.
func (r *Router) HTTPListenerPlain() net.Listener {
if r.httpPlainListener == nil {
return nil
}
return r.httpPlainListener
}
// AddRoute registers an SNI route. Multiple routes for the same host are
// stored and resolved by priority at lookup time (HTTP > TCP).
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
@@ -254,6 +290,9 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
if r.httpListener != nil {
r.httpListener.Close()
}
if r.httpPlainListener != nil {
r.httpPlainListener.Close()
}
case <-done:
}
}()
@@ -270,6 +309,7 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
r.logger.Debugf("SNI router accept: %v", err)
continue
}
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {
defer r.activeConns.Done()
@@ -278,13 +318,24 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}
// HandleConn lets external accept loops feed a connection through the
// router's peek-and-dispatch logic. Use this when the same router serves
// a secondary listener (for example, a per-account inbound :80 socket
// alongside its :443 socket).
func (r *Router) HandleConn(ctx context.Context, conn net.Conn) {
r.activeConns.Add(1)
defer r.activeConns.Done()
r.handleConn(ctx, conn)
}
// handleConn peeks at the TLS ClientHello and routes the connection.
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
// fallback port), skip the TLS peek entirely to avoid read errors on
// non-TLS connections and reduce latency.
if r.isFallbackOnly() {
r.handleUnmatched(ctx, conn)
r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr())
r.handleUnmatched(ctx, conn, false)
return
}
@@ -294,11 +345,11 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
return
}
sni, wrapped, err := PeekClientHello(conn)
sni, wrapped, isTLS, err := PeekClientHello(conn)
if err != nil {
r.logger.Debugf("SNI peek: %v", err)
r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err)
if wrapped != nil {
r.handleUnmatched(ctx, wrapped)
r.handleUnmatched(ctx, wrapped, isTLS)
} else {
_ = conn.Close()
}
@@ -313,13 +364,20 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
host := SNIHost(strings.ToLower(sni))
route, ok := r.lookupRoute(host)
r.logger.WithFields(log.Fields{
"remote": wrapped.RemoteAddr().String(),
"sni": string(host),
"match": ok,
"tls": isTLS,
}).Debug("SNI route lookup")
if !ok {
r.handleUnmatched(ctx, wrapped)
r.handleUnmatched(ctx, wrapped, isTLS)
return
}
if route.Type == RouteHTTP {
r.sendToHTTP(wrapped)
r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID)
r.sendToHTTP(wrapped, isTLS)
return
}
@@ -344,15 +402,17 @@ func (r *Router) isFallbackOnly() bool {
}
// handleUnmatched routes a connection that didn't match any SNI route.
// This includes ECH/ESNI connections where the cleartext SNI is empty.
// This includes ECH/ESNI connections where the cleartext SNI is empty,
// and plain (non-TLS) HTTP connections when isTLS is false.
// It tries the fallback relay first, then the HTTP channel, and closes
// the connection if neither is available.
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) {
r.mu.RLock()
fb := r.fallback
r.mu.RUnlock()
if fb != nil {
r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target)
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
if !errors.Is(err, errAccessRestricted) {
r.logger.WithFields(log.Fields{
@@ -364,7 +424,8 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
}
return
}
r.sendToHTTP(conn)
r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr())
r.sendToHTTP(conn, isTLS)
}
// lookupRoute returns the highest-priority route for the given SNI host.
@@ -386,10 +447,20 @@ func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
}
// sendToHTTP feeds the connection to the HTTP handler via the channel.
// If no HTTP channel is configured (port router), the router is
// draining, or the channel is full, the connection is closed.
func (r *Router) sendToHTTP(conn net.Conn) {
if r.httpCh == nil {
// When isTLS is false and a plain channel is configured the connection
// is forwarded to the plain channel; otherwise it lands on the TLS
// channel. If no usable channel exists, the router is draining, or the
// channel is full, the connection is closed.
func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) {
ch := r.httpCh
chanName := "HTTP"
if !isTLS && r.httpPlainCh != nil {
ch = r.httpPlainCh
chanName = "HTTP-plain"
}
if ch == nil {
r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr())
_ = conn.Close()
return
}
@@ -399,14 +470,15 @@ func (r *Router) sendToHTTP(conn net.Conn) {
r.mu.RUnlock()
if draining {
r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr())
_ = conn.Close()
return
}
select {
case r.httpCh <- conn:
case ch <- conn:
default:
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr())
_ = conn.Close()
}
}

View File

@@ -1739,3 +1739,97 @@ func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
}
// TestRouter_PlainHTTP_RoutesToPlainChannel verifies that a plain (non-TLS)
// connection lands on the plain HTTP channel when the router was built
// with WithPlainHTTP, leaving the TLS channel untouched.
func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
router.AddRoute("example.com", Route{Type: RouteHTTP})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "test listener bind must succeed")
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
_ = router.Serve(ctx, ln)
}()
// Plain HTTP request (no TLS handshake byte).
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}()
plainListener := router.HTTPListenerPlain()
require.NotNil(t, plainListener, "plain listener must be exposed when WithPlainHTTP is set")
acceptDone := make(chan net.Conn, 1)
go func() {
conn, err := plainListener.Accept()
if err == nil {
acceptDone <- conn
}
}()
select {
case conn := <-acceptDone:
require.NotNil(t, conn)
_ = conn.Close()
case <-router.HTTPListener().(*chanListener).ch:
t.Fatal("plain HTTP request leaked into TLS channel")
case <-time.After(3 * time.Second):
t.Fatal("plain HTTP connection never reached plain channel")
}
}
// TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled verifies that the
// presence of a plain channel does not divert TLS traffic — TLS still
// goes to the TLS channel as before.
func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
router.AddRoute("example.com", Route{Type: RouteHTTP})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "test listener bind must succeed")
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = router.Serve(ctx, ln) }()
// Send a TLS ClientHello.
go func() {
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
if err != nil {
return
}
tlsConn := tls.Client(conn, &tls.Config{
ServerName: "example.com",
InsecureSkipVerify: true, //nolint:gosec
})
_ = tlsConn.Handshake()
_ = tlsConn.Close()
}()
select {
case conn := <-router.httpCh:
require.NotNil(t, conn, "TLS conn should land on the TLS channel")
_ = conn.Close()
case <-time.After(5 * time.Second):
t.Fatal("TLS conn never reached the TLS channel")
}
}

View File

@@ -30,26 +30,30 @@ const (
// bytes transparently. If the data is not a valid TLS ClientHello or
// contains no SNI extension, sni is empty and err is nil.
//
// isTLS reports whether the first byte indicated a TLS handshake record.
// Callers can use this to distinguish plain (non-TLS) traffic from a TLS
// stream that simply lacked an SNI extension or used ECH.
//
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
// real server name is encrypted inside the encrypted_client_hello
// extension. This parser only reads the cleartext server_name extension
// (type 0x0000), so ECH connections return sni="" and are routed through
// the fallback path (or HTTP channel), which is the correct behavior
// for a transparent proxy that does not terminate TLS.
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, isTLS bool, err error) {
// Read the 5-byte TLS record header into a small stack-friendly buffer.
var header [tlsRecordHeaderLen]byte
if _, err := io.ReadFull(conn, header[:]); err != nil {
return "", nil, fmt.Errorf("read TLS record header: %w", err)
return "", nil, false, fmt.Errorf("read TLS record header: %w", err)
}
if header[0] != contentTypeHandshake {
return "", newPeekedConn(conn, header[:]), nil
return "", newPeekedConn(conn, header[:]), false, nil
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
if recordLen == 0 || recordLen > maxClientHelloLen {
return "", newPeekedConn(conn, header[:]), nil
return "", newPeekedConn(conn, header[:]), true, nil
}
// Single allocation for header + payload. The peekedConn takes
@@ -59,11 +63,11 @@ func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
if err != nil {
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), true, fmt.Errorf("read TLS handshake payload: %w", err)
}
sni = extractSNI(buf[tlsRecordHeaderLen:])
return sni, newPeekedConn(conn, buf), nil
return sni, newPeekedConn(conn, buf), true, nil
}
// extractSNI parses a TLS handshake payload to find the SNI extension.

View File

@@ -29,10 +29,11 @@ func TestPeekClientHello_ValidSNI(t *testing.T) {
_ = tlsConn.Handshake()
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
assert.NotNil(t, wrapped, "wrapped connection should not be nil")
assert.True(t, isTLS, "TLS ClientHello should be flagged as TLS")
// Verify the wrapped connection replays the peeked bytes.
// Read the first 5 bytes (TLS record header) to confirm replay.
@@ -83,10 +84,11 @@ func TestPeekClientHello_MultipleSNIs(t *testing.T) {
_ = tlsConn.Handshake()
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Equal(t, tt.expectedSNI, sni)
assert.NotNil(t, wrapped)
assert.True(t, isTLS, "TLS handshake should be flagged as TLS")
})
}
}
@@ -102,10 +104,11 @@ func TestPeekClientHello_NonTLSData(t *testing.T) {
_, _ = clientConn.Write(httpData)
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Empty(t, sni, "should return empty SNI for non-TLS data")
assert.NotNil(t, wrapped)
assert.False(t, isTLS, "plain HTTP data should not be flagged as TLS")
// Verify the wrapped connection still provides the original data.
buf := make([]byte, len(httpData))
@@ -124,7 +127,7 @@ func TestPeekClientHello_TruncatedHeader(t *testing.T) {
clientConn.Close()
}()
_, _, err := PeekClientHello(serverConn)
_, _, _, err := PeekClientHello(serverConn)
assert.Error(t, err, "should error on truncated header")
}
@@ -140,7 +143,7 @@ func TestPeekClientHello_TruncatedPayload(t *testing.T) {
clientConn.Close()
}()
_, _, err := PeekClientHello(serverConn)
_, _, _, err := PeekClientHello(serverConn)
assert.Error(t, err, "should error on truncated payload")
}
@@ -154,10 +157,11 @@ func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
}()
sni, wrapped, err := PeekClientHello(serverConn)
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
require.NoError(t, err)
assert.Empty(t, sni)
assert.NotNil(t, wrapped)
assert.True(t, isTLS, "zero-length record should still be a TLS handshake byte")
}
func TestExtractSNI_InvalidPayload(t *testing.T) {

160
proxy/lifecycle.go Normal file
View File

@@ -0,0 +1,160 @@
package proxy
import (
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/acme"
)
// Config bundles every knob the proxy reads at construction time. It mirrors
// the public fields on Server so library callers don't have to learn the
// internal struct layout. Zero values mean "feature off" or "fall back to the
// internal default" depending on the field — see the per-field doc.
//
// The standalone binary continues to populate Server fields directly, so
// adding fields here must not change the zero-value behaviour of Server.
type Config struct {
// ListenAddr is the TCP address the main listener binds. Required.
ListenAddr string
// ID identifies this proxy instance to management. Empty value lets
// New generate a timestamped default.
ID string
// Logger is the logrus logger used everywhere. Empty value falls back
// to log.StandardLogger().
Logger *log.Logger
// Version is the build version string reported to management. Empty
// becomes "dev".
Version string
// ProxyURL is the public address operators use to reach this proxy.
ProxyURL string
// ManagementAddress is the gRPC URL of the management server.
ManagementAddress string
// ProxyToken authenticates this proxy with the management server.
ProxyToken string
// CertificateDirectory is the directory holding TLS certificate
// material (static or ACME-provisioned).
CertificateDirectory string
// CertificateFile is the certificate filename within
// CertificateDirectory.
CertificateFile string
// CertificateKeyFile is the private key filename within
// CertificateDirectory.
CertificateKeyFile string
// GenerateACMECertificates toggles ACME certificate provisioning.
GenerateACMECertificates bool
// ACMEChallengeAddress is the listen address for HTTP-01 challenges.
ACMEChallengeAddress string
// ACMEDirectory is the ACME directory URL (Let's Encrypt by default).
ACMEDirectory string
// ACMEEABKID is the External Account Binding Key ID for CAs that
// require EAB (e.g. ZeroSSL).
ACMEEABKID string
// ACMEEABHMACKey is the External Account Binding HMAC key for CAs
// that require EAB.
ACMEEABHMACKey string
// ACMEChallengeType is the ACME challenge type ("tls-alpn-01" or
// "http-01"). Empty defaults to "tls-alpn-01".
ACMEChallengeType string
// CertLockMethod controls how ACME certificate locks are coordinated
// across replicas.
CertLockMethod acme.CertLockMethod
// WildcardCertDir is an optional directory containing static wildcard
// certificates that override ACME for matching domains.
WildcardCertDir string
// DebugEndpointEnabled toggles the debug HTTP endpoint.
DebugEndpointEnabled bool
// DebugEndpointAddress is the bind address for the debug endpoint.
DebugEndpointAddress string
// HealthAddr is the bind address for the health probe and metrics
// surface. Empty disables the health probe entirely (library callers
// can attach their own).
HealthAddr string
// ForwardedProto overrides the X-Forwarded-Proto value sent to
// backends. Valid values: "auto", "http", "https".
ForwardedProto string
// TrustedProxies is a list of IP prefixes for trusted upstream
// proxies that may set forwarding headers.
TrustedProxies []netip.Prefix
// WireguardPort is the UDP port for the embedded NetBird tunnel.
// Zero asks the OS for a random port.
WireguardPort uint16
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
ProxyProtocol bool
// PreSharedKey is the WireGuard pre-shared key used between the
// proxy's embedded clients and peers.
PreSharedKey string
// SupportsCustomPorts indicates whether the proxy can bind arbitrary
// ports for TCP/UDP/TLS services.
SupportsCustomPorts bool
// RequireSubdomain forces accounts to use a subdomain in front of
// the proxy's cluster domain.
RequireSubdomain bool
// Private flags this proxy as embedded in a netbird client and
// serving exclusively over the WireGuard tunnel. Also enables
// per-account inbound listeners on each embedded client's netstack.
Private bool
// MaxDialTimeout caps the per-service backend dial timeout.
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
// CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec.
CrowdSecAPIURL string
// CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables
// CrowdSec.
CrowdSecAPIKey string
}
// New builds a Server from cfg without performing any I/O. No goroutines
// are spawned, no network connections are dialed, and no listeners are
// bound — call Start to bring the proxy up. Returning a fully-formed
// Server keeps the standalone code path (which still constructs Server
// directly) byte-for-byte equivalent.
func New(cfg Config) *Server {
return &Server{
ListenAddr: cfg.ListenAddr,
ID: cfg.ID,
Logger: cfg.Logger,
Version: cfg.Version,
ProxyURL: cfg.ProxyURL,
ManagementAddress: cfg.ManagementAddress,
ProxyToken: cfg.ProxyToken,
CertificateDirectory: cfg.CertificateDirectory,
CertificateFile: cfg.CertificateFile,
CertificateKeyFile: cfg.CertificateKeyFile,
GenerateACMECertificates: cfg.GenerateACMECertificates,
ACMEChallengeAddress: cfg.ACMEChallengeAddress,
ACMEDirectory: cfg.ACMEDirectory,
ACMEEABKID: cfg.ACMEEABKID,
ACMEEABHMACKey: cfg.ACMEEABHMACKey,
ACMEChallengeType: cfg.ACMEChallengeType,
CertLockMethod: cfg.CertLockMethod,
WildcardCertDir: cfg.WildcardCertDir,
DebugEndpointEnabled: cfg.DebugEndpointEnabled,
DebugEndpointAddress: cfg.DebugEndpointAddress,
HealthAddress: cfg.HealthAddr,
ForwardedProto: cfg.ForwardedProto,
TrustedProxies: cfg.TrustedProxies,
WireguardPort: cfg.WireguardPort,
ProxyProtocol: cfg.ProxyProtocol,
PreSharedKey: cfg.PreSharedKey,
SupportsCustomPorts: cfg.SupportsCustomPorts,
RequireSubdomain: cfg.RequireSubdomain,
Private: cfg.Private,
MaxDialTimeout: cfg.MaxDialTimeout,
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
GeoDataDir: cfg.GeoDataDir,
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
}
}

134
proxy/lifecycle_test.go Normal file
View File

@@ -0,0 +1,134 @@
package proxy
import (
"context"
"errors"
"io"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// quietLifecycleLogger keeps lifecycle tests from spamming the test output.
func quietLifecycleLogger() *log.Logger {
l := log.New()
l.SetOutput(io.Discard)
l.SetLevel(log.PanicLevel)
return l
}
func TestNewIsPureConstructor(t *testing.T) {
cfg := Config{
ListenAddr: ":0",
ID: "test-id",
Logger: quietLifecycleLogger(),
Version: "test",
ManagementAddress: "https://example.invalid",
HealthAddr: "",
ForwardedProto: "auto",
}
srv := New(cfg)
require.NotNil(t, srv, "New must return a non-nil Server")
assert.Equal(t, ":0", srv.ListenAddr, "ListenAddr should round-trip")
assert.Equal(t, "test-id", srv.ID, "ID should round-trip")
assert.Equal(t, "test", srv.Version, "Version should round-trip")
assert.Equal(t, "https://example.invalid", srv.ManagementAddress, "ManagementAddress should round-trip")
assert.Equal(t, "auto", srv.ForwardedProto, "ForwardedProto should round-trip")
// Pure constructor: no goroutines, no listener bind, no management dial.
assert.False(t, srv.started, "Server must be marked unstarted before Start")
assert.Nil(t, srv.mgmtClient, "mgmt client must not be created in New")
assert.Nil(t, srv.netbird, "netbird client must not be created in New")
assert.Nil(t, srv.https, "https server must not be created in New")
assert.Nil(t, srv.healthServer, "health server must not be created in New")
assert.Nil(t, srv.runCancel, "runCancel must be nil before Start")
assert.Nil(t, srv.runErrCh, "runErrCh must be nil before Start")
}
func TestStopBeforeStartIsNoOp(t *testing.T) {
srv := New(Config{Logger: quietLifecycleLogger()})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := srv.Stop(ctx)
assert.NoError(t, err, "Stop on an unstarted server must succeed without error")
err = srv.Stop(ctx)
assert.NoError(t, err, "Stop must remain idempotent across repeated calls")
}
func TestStartFailsWithoutManagement(t *testing.T) {
srv := New(Config{
Logger: quietLifecycleLogger(),
ListenAddr: "127.0.0.1:0",
ManagementAddress: "://broken-url",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := srv.Start(ctx)
require.Error(t, err, "Start must surface management dial failures")
assert.True(t, srv.started, "started flag is set before any dial attempt so a second Start fails fast")
err = srv.Start(ctx)
require.Error(t, err, "second Start must reject")
assert.Contains(t, err.Error(), "already started", "error must explain why the call was rejected")
}
func TestStopIsIdempotent(t *testing.T) {
srv := &Server{
Logger: quietLifecycleLogger(),
started: true,
runErrCh: make(chan struct{}),
runCancel: func() {},
}
srv.recordRunErr(errors.New("synthetic"))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
err := srv.Stop(ctx)
require.Error(t, err, "Stop must surface the recorded background error")
assert.Contains(t, err.Error(), "synthetic", "error must round-trip recordRunErr's value")
err = srv.Stop(ctx)
require.Error(t, err, "second Stop must still report the same error")
assert.Contains(t, err.Error(), "synthetic", "idempotent Stop must return the cached error")
}
func TestRecordRunErrPreservesFirstFailure(t *testing.T) {
srv := &Server{
Logger: quietLifecycleLogger(),
runErrCh: make(chan struct{}),
}
srv.recordRunErr(errors.New("first"))
srv.recordRunErr(errors.New("second"))
require.Error(t, srv.runErr, "first failure must be retained")
assert.Contains(t, srv.runErr.Error(), "first", "second call must not overwrite the cached error")
select {
case <-srv.runErrCh:
default:
t.Fatal("recordRunErr must close runErrCh so waitAndStop unblocks")
}
}
func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) {
srv := New(Config{Logger: quietLifecycleLogger()})
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := srv.Stop(ctx)
assert.NoError(t, err, "Stop on an unstarted server should not block on the cancelled ctx")
}

View File

@@ -239,6 +239,10 @@ func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
return nil
}
func (m *testProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
@@ -565,6 +569,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
proxytypes.AccountID(mapping.GetAccountId()),
proxytypes.ServiceID(mapping.GetId()),
nil,
mapping.GetPrivate(),
)
require.NoError(t, err)

View File

@@ -32,11 +32,10 @@ import (
"go.opentelemetry.io/otel/sdk/metric"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/internal/accesslog"
@@ -114,9 +113,28 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
inbound *inboundManager
// Lifecycle state — populated by Start, consumed by Stop. The fields
// stay zero on a fresh Server until Start runs so direct struct
// construction (`&Server{...}`) used by tests still works.
runCancel context.CancelFunc
mgmtConn *grpc.ClientConn
runErr error
runErrCh chan struct{}
startMu sync.Mutex
started bool
stopOnce sync.Once
// Mostly used for debugging on management.
startTime time.Time
// ListenAddr is the address the main TCP listener binds. Populated by
// New from Config or by ListenAndServe from its addr argument.
ListenAddr string
ID string
Logger *log.Logger
Version string
@@ -177,6 +195,14 @@ type Server struct {
// in front of this proxy's cluster domain. When true, accounts cannot
// create services on the bare cluster domain.
RequireSubdomain bool
// Private flags this proxy as embedded in a netbird client and serving
// exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported
// upstream as a capability so dashboards can distinguish per-peer
// clusters from centralised ones, and turns on per-account inbound
// listeners on each embedded client's netstack: every account that
// registers a service exposes :80 + :443 inside its own WG tunnel,
// scoped to that account's services only.
Private bool
// MaxDialTimeout caps the per-service backend dial timeout.
// When the API sends a timeout, it is clamped to this value.
// When the API sends no timeout, this value is used as the default.
@@ -222,12 +248,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se
status = proto.ProxyStatus_PROXY_STATUS_ACTIVE
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{
req := &proto.SendStatusUpdateRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Status: status,
CertificateIssued: false,
})
}
if connected {
req.InboundListener = s.inboundListenerProto(accountID)
}
_, err := s.mgmtClient.SendStatusUpdate(ctx, req)
return err
}
@@ -238,56 +268,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
AccountId: string(accountID),
Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE,
CertificateIssued: true,
InboundListener: s.inboundListenerProto(accountID),
})
return err
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.initDefaults()
s.routerReady = make(chan struct{})
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
s.portRouters = make(map[uint16]*portRouter)
s.svcPorts = make(map[types.ServiceID][]uint16)
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
// inboundListenerProto resolves the per-account inbound listener state for
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
// or the account has no live listener so management treats the field as
// absent.
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
if s.inbound == nil {
return nil
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
info, ok := s.inbound.ListenerInfo(accountID)
if !ok || info.TunnelIP == "" {
return nil
}
return &proto.ProxyInboundListener{
TunnelIp: info.TunnelIP,
HttpsPort: uint32(info.HTTPSPort),
HttpPort: uint32(info.HTTPPort),
}
}
mgmtConn, err := s.dialManagement()
if err != nil {
// ListenAndServe is the standalone entrypoint. It binds the listener, runs
// the proxy until ctx is cancelled or a background goroutine fails, then
// drains and stops. Library callers should prefer New + Start + Stop and
// own their own shutdown signalling.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
s.ListenAddr = addr
if err := s.Start(ctx); err != nil {
return err
}
defer func() {
if err := mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}()
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
return s.waitAndStop(ctx)
}
// Start brings the proxy up: dials management, configures TLS, binds the
// main listener, and spawns the SNI router and HTTPS server goroutines. It
// returns once the listener is bound; background errors are surfaced
// through Stop's return value. Start is not safe to call twice.
func (s *Server) Start(ctx context.Context) error {
s.startMu.Lock()
if s.started {
s.startMu.Unlock()
return errors.New("proxy already started")
}
s.started = true
s.startMu.Unlock()
s.initLifecycleState()
if err := s.initMetrics(ctx); err != nil {
return err
}
if err := s.initManagementClient(); err != nil {
return err
}
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
s.runCancel = runCancel
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration
// Create health checker before the mapping worker so it can track
// management connectivity from the first stream connection.
s.initNetBirdClient()
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger))
@@ -300,34 +342,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger)
s.initReverseProxy()
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
if err := s.initGeoLookup(); err != nil {
return err
}
var startupOK bool
startupOK := false
defer func() {
if startupOK {
return
}
if s.geoRaw != nil {
if err := s.geoRaw.Close(); err != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", err)
if closeErr := s.geoRaw.Close(); closeErr != nil {
s.Logger.Debugf("close geolocation on startup failure: %v", closeErr)
}
}
}()
// Configure the authentication middleware with session validator for OIDC group checks.
s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo)
// Configure Access logs to management server.
s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies)
s.startDebugEndpoint()
@@ -336,35 +369,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Build the handler chain from inside out.
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = s.accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
handler = s.hijackTracker.Middleware(handler)
handler := s.buildHandlerChain()
s.initPrivateInbound(handler, tlsConfig)
// Start a raw TCP listener; the SNI router peeks at ClientHello
// and routes to either the HTTP handler or a TCP relay.
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", addr)
ln, err := s.bindMainListener(ctx)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
return err
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
// Set up the SNI router for TCP/HTTP multiplexing on the main port.
s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr())
s.mainRouter.SetObserver(s.meter)
s.mainRouter.SetAccessLogger(s.accessLog)
close(s.routerReady)
// The HTTP server uses the chanListener fed by the SNI router.
s.https = &http.Server{
Addr: addr,
Addr: s.ListenAddr,
Handler: handler,
TLSConfig: tlsConfig,
ReadHeaderTimeout: httpReadHeaderTimeout,
@@ -374,35 +393,200 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
startupOK = true
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debug("starting HTTPS server on SNI router HTTP channel")
httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "")
if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
s.recordRunErr(fmt.Errorf("https server: %w", serveErr))
}
}()
routerErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting SNI router on %s", addr)
routerErr <- s.mainRouter.Serve(runCtx, ln)
s.Logger.Debugf("starting SNI router on %s", s.ListenAddr)
if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil {
s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr))
}
}()
return nil
}
// Stop drains in-flight connections, shuts down all background services,
// and releases resources. Idempotent; calling it before Start is a no-op.
// Returns the first fatal error reported by a background goroutine, if
// any. The provided ctx bounds the total wait time; once it is cancelled
// Stop returns even if drain is still in flight.
func (s *Server) Stop(ctx context.Context) error {
s.stopOnce.Do(func() {
s.startMu.Lock()
started := s.started
s.startMu.Unlock()
if !started {
return
}
done := make(chan struct{})
go func() {
defer close(done)
s.gracefulShutdown()
if s.runCancel != nil {
s.runCancel()
}
if s.mgmtConn != nil {
if err := s.mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}
}()
select {
case <-done:
case <-ctx.Done():
s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err())
}
})
s.startMu.Lock()
defer s.startMu.Unlock()
return s.runErr
}
// waitAndStop blocks until ctx is cancelled or a background goroutine
// reports a fatal error, then drains and stops. Used by ListenAndServe.
func (s *Server) waitAndStop(ctx context.Context) error {
select {
case err := <-httpsErr:
s.shutdownServices()
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("https server: %w", err)
}
return nil
case err := <-routerErr:
s.shutdownServices()
if err != nil {
return fmt.Errorf("SNI router: %w", err)
}
return nil
case <-ctx.Done():
s.gracefulShutdown()
return nil
case <-s.runErrCh:
}
stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout)
defer cancel()
return s.Stop(stopCtx)
}
// recordRunErr stores the first fatal background error and signals
// waitAndStop. Subsequent errors are logged at debug level so the first
// cause is preserved.
func (s *Server) recordRunErr(err error) {
s.startMu.Lock()
defer s.startMu.Unlock()
if s.runErr != nil {
s.Logger.Debugf("background error after first failure: %v", err)
return
}
s.runErr = err
if s.runErrCh != nil {
close(s.runErrCh)
}
}
// initLifecycleState seeds the maps and channels Start needs to wire up
// background goroutines. Called once at the top of Start.
func (s *Server) initLifecycleState() {
s.initDefaults()
s.routerReady = make(chan struct{})
s.runErrCh = make(chan struct{})
s.udpRelays = make(map[types.ServiceID]*udprelay.Relay)
s.portRouters = make(map[uint16]*portRouter)
s.svcPorts = make(map[types.ServiceID][]uint16)
s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping)
}
// initMetrics builds the prometheus exporter and meter bundle.
func (s *Server) initMetrics(ctx context.Context) error {
exporter, err := prometheus.New()
if err != nil {
return fmt.Errorf("create prometheus exporter: %w", err)
}
provider := metric.NewMeterProvider(metric.WithReader(exporter))
pkg := reflect.TypeOf(Server{}).PkgPath()
meter := provider.Meter(pkg)
s.meter, err = proxymetrics.New(ctx, meter)
if err != nil {
return fmt.Errorf("create metrics: %w", err)
}
return nil
}
// initManagementClient dials management and stashes the connection so
// Stop can close it deterministically.
func (s *Server) initManagementClient() error {
conn, err := s.dialManagement()
if err != nil {
return err
}
s.mgmtConn = conn
s.mgmtClient = proto.NewProxyServiceClient(conn)
return nil
}
// initNetBirdClient builds the multi-tenant embedded NetBird client used
// for outbound RoundTripping and (when --private-inbound is on) per-account
// inbound listeners.
func (s *Server) initNetBirdClient() {
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
// On --private the embedded client serves per-account inbound
// listeners and must apply management's ACL: keep BlockInbound off
// so the engine creates the ACL manager. On the standalone proxy
// the embedded client never accepts inbound, so block.
BlockInbound: !s.Private,
}, s.Logger, s, s.mgmtClient)
}
// initReverseProxy builds the meter-instrumented reverse proxy. MultiTransport
// routes targets opted into direct_upstream through the host's network stack
// (stdlib transport); everything else falls through to the embedded NetBird
// client. The split is needed so direct_upstream targets resolve DNS via the
// proxy host's resolver instead of the tunnel's DNS.
func (s *Server) initReverseProxy() {
upstreamRT := roundtrip.NewMultiTransport(s.netbird, s.Logger)
s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger)
}
// initGeoLookup configures the GeoLite2 lookup used for country-based
// access restrictions and access-log enrichment.
func (s *Server) initGeoLookup() error {
geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir)
if err != nil {
return fmt.Errorf("initialize geolocation: %w", err)
}
s.geoRaw = geoLookup
if geoLookup != nil {
s.geo = geoLookup
}
return nil
}
// buildHandlerChain wires the request middlewares from inside out.
func (s *Server) buildHandlerChain() http.Handler {
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = s.accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
return s.hijackTracker.Middleware(handler)
}
// bindMainListener binds the main TCP listener and wraps it with PROXY
// protocol when configured.
func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) {
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", s.ListenAddr)
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err)
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid
s.Logger.WithFields(log.Fields{
"requested_addr": s.ListenAddr,
"bound_addr": ln.Addr().String(),
"private": s.Private,
"proxy_protocol": s.ProxyProtocol,
}).Info("proxy main listener bound")
return ln, nil
}
// initDefaults sets fallback values for optional Server fields.
@@ -434,6 +618,9 @@ func (s *Server) startDebugEndpoint() {
if s.acme != nil {
debugHandler.SetCertStatus(s.acme)
}
if s.inbound != nil {
debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound})
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
@@ -447,16 +634,18 @@ func (s *Server) startDebugEndpoint() {
}()
}
// startHealthServer launches the health probe and metrics server.
// startHealthServer launches the health probe and metrics server. Empty
// HealthAddress disables the probe entirely (intended for library callers
// that want to manage their own health surface).
func (s *Server) startHealthServer() error {
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = defaultHealthAddr
if s.HealthAddress == "" {
s.Logger.Debug("health probe disabled (empty HealthAddress)")
return nil
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", healthAddr)
s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true}))
healthListener, err := net.Listen("tcp", s.HealthAddress)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
@@ -507,8 +696,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr
}
const (
defaultHealthAddr = "localhost:8080"
defaultDebugAddr = "localhost:8444"
// defaultDebugAddr is the localhost-bound fallback for the debug endpoint
// when DebugEndpointAddress is empty.
defaultDebugAddr = "localhost:8444"
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
// header after accepting a connection.
@@ -661,8 +851,10 @@ func (s *Server) gracefulShutdown() {
defer drainCancel()
s.Logger.Info("draining in-flight connections")
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
if s.https != nil {
if err := s.https.Shutdown(drainCtx); err != nil {
s.Logger.Warnf("https server drain: %v", err)
}
}
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
@@ -809,6 +1001,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu
return client.DialContext, nil
}
// initPrivateInbound wires per-account inbound listeners when --private
// is set. When the flag is off this is a no-op and the standalone proxy keeps
// its byte-for-byte previous behaviour.
func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) {
if !s.Private {
return
}
s.inbound = newInboundManager(s.Logger, handler, tlsConfig)
s.netbird.SetClientLifecycle(s.inbound.onClientReady, s.inbound.onClientStop)
s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)")
}
// notifyError reports a resource error back to management so it can be
// surfaced to the user (e.g. port bind failure, dialer resolution error).
func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) {
@@ -941,9 +1145,6 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
Clock: backoff.SystemClock,
}
// syncSupported tracks whether management supports SyncMappings.
// Starts true; set to false on first Unimplemented error.
syncSupported := true
initialSyncDone := false
operation := func() error {
@@ -955,25 +1156,41 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
supportsCrowdSec := s.crowdsecRegistry.Available()
privateCapability := s.Private
// Always true: this build enforces ProxyMapping.private via the auth middleware.
supportsPrivateService := true
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: &proto.ProxyCapabilities{
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
Private: &privateCapability,
SupportsPrivateService: &supportsPrivateService,
},
})
if err != nil {
return fmt.Errorf("create mapping stream: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established")
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone, time.Now())
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
if streamErr == nil {
return fmt.Errorf("stream closed by server")
}
@@ -990,187 +1207,56 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
}
}
func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
supportsCrowdSec := s.crowdsecRegistry.Available()
return &proto.ProxyCapabilities{
SupportsCustomPorts: &s.SupportsCustomPorts,
RequireSubdomain: &s.RequireSubdomain,
SupportsCrowdsec: &supportsCrowdSec,
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: s.proxyCapabilities(),
})
if err != nil {
return fmt.Errorf("create mapping stream: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established (GetMappingUpdate)")
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
return fmt.Errorf("create sync stream: %w", err)
}
// Send init message.
if err := stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: s.ID,
Version: s.Version,
StartedAt: timestamppb.New(s.startTime),
Address: s.ProxyURL,
Capabilities: s.proxyCapabilities(),
},
},
}); err != nil {
return fmt.Errorf("send sync init: %w", err)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("management mapping stream established (SyncMappings)")
return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime)
}
func isSyncUnimplemented(err error) bool {
if err == nil {
return false
}
st, ok := grpcstatus.FromError(err)
return ok && st.Code() == codes.Unimplemented
}
func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error {
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, _ time.Time) error {
select {
case <-s.routerReady:
case <-ctx.Done():
return ctx.Err()
}
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
msg, err := stream.Recv()
switch {
case errors.Is(err, io.EOF):
return nil
case err != nil:
return fmt.Errorf("receive msg: %w", err)
}
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
if err := stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
}); err != nil {
return fmt.Errorf("send ack: %w", err)
}
}
}
}
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error {
select {
case <-s.routerReady:
case <-ctx.Done():
return ctx.Err()
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
tracker := s.newSnapshotTracker(initialSyncDone, connectTime)
for {
// Check for context completion to gracefully shutdown.
select {
case <-ctx.Done():
// Shutting down.
return ctx.Err()
default:
msg, err := mappingClient.Recv()
switch {
case errors.Is(err, io.EOF):
// Mapping connection gracefully terminated by server.
return nil
case err != nil:
// Something has gone horribly wrong, return and hope the parent retries the connection.
return fmt.Errorf("receive msg: %w", err)
}
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
if !*initialSyncDone {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
}
// snapshotTracker accumulates service IDs during the initial snapshot and
// finalises sync state when the complete flag arrives.
type snapshotTracker struct {
done *bool
connectTime time.Time
snapshotIDs map[types.ServiceID]struct{}
}
func (s *Server) newSnapshotTracker(done *bool, connectTime time.Time) *snapshotTracker {
var ids map[types.ServiceID]struct{}
if !*done {
ids = make(map[types.ServiceID]struct{})
}
return &snapshotTracker{done: done, connectTime: connectTime, snapshotIDs: ids}
}
func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings []*proto.ProxyMapping, syncComplete bool, batchStart time.Time) {
if *t.done {
return
}
if s.meter != nil {
s.meter.RecordSnapshotBatchDuration(time.Since(batchStart))
}
for _, m := range mappings {
t.snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if !syncComplete {
return
}
s.reconcileSnapshot(ctx, t.snapshotIDs)
t.snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*t.done = true
if s.meter != nil {
s.meter.RecordSnapshotSyncDuration(time.Since(t.connectTime))
}
s.Logger.Info("Initial mapping sync complete")
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
@@ -1192,17 +1278,29 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
s.ensurePeers(ctx, mappings)
// mappingJSONMarshal dumps mappings on one line with zero-value fields visible for debug logs.
var mappingJSONMarshal = protojson.MarshalOptions{
Multiline: false,
EmitUnpopulated: true,
UseProtoNames: true,
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
// The full proto dump carries auth_token and header-auth values; gate on debug.
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"mode": mapping.GetMode(),
"port": mapping.GetListenPort(),
"id": mapping.GetId(),
}).Debug("Processing mapping update")
if debug {
raw, err := mappingJSONMarshal.Marshal(mapping)
if err != nil {
raw = []byte(fmt.Sprintf("<marshal error: %v>", err))
}
s.Logger.WithFields(log.Fields{
"type": mapping.GetType(),
"domain": mapping.GetDomain(),
"id": mapping.GetId(),
"mapping": string(raw),
}).Debug("Processing mapping update")
}
switch mapping.GetType() {
case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED:
if err := s.addMapping(ctx, mapping); err != nil {
@@ -1228,60 +1326,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap
}
}
// ensurePeers pre-creates NetBird peers for all unique accounts referenced by
// CREATED mappings. Peers for different accounts are created concurrently,
// which avoids serializing N×100ms gRPC round-trips during large initial syncs.
func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) {
// Collect one representative mapping per account that needs a new peer.
type peerReq struct {
accountID types.AccountID
svcKey roundtrip.ServiceKey
authToken string
svcID types.ServiceID
}
seen := make(map[types.AccountID]struct{})
var reqs []peerReq
for _, m := range mappings {
if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
continue
}
accountID := types.AccountID(m.GetAccountId())
if _, ok := seen[accountID]; ok {
continue
}
seen[accountID] = struct{}{}
if s.netbird.HasClient(accountID) {
continue
}
reqs = append(reqs, peerReq{
accountID: accountID,
svcKey: s.serviceKeyForMapping(m),
authToken: m.GetAuthToken(),
svcID: types.ServiceID(m.GetId()),
})
}
if len(reqs) <= 1 {
return
}
var wg sync.WaitGroup
wg.Add(len(reqs))
for _, r := range reqs {
go func() {
defer wg.Done()
if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": r.accountID,
"service_id": r.svcID,
"error": err,
}).Warn("failed to pre-create peer for account")
}
}()
}
wg.Wait()
}
// addMapping registers a service mapping and starts the appropriate relay or routes.
func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
accountID := types.AccountID(mapping.GetAccountId())
@@ -1353,12 +1397,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
AccountID: accountID,
ServiceID: svcID,
Domain: mapping.GetDomain(),
})
}
s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
if s.inbound != nil {
s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute)
}
if err := s.updateMapping(ctx, mapping); err != nil {
return fmt.Errorf("update mapping for domain %q: %w", d, err)
}
@@ -1718,7 +1766,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions())
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil {
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil {
return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err)
}
m := s.protoToMapping(ctx, mapping)
@@ -1774,6 +1822,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) {
}
// Remove SNI route from the main router (covers both HTTP and main-port TLS).
s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID)
if s.inbound != nil {
s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID)
}
}
// Extract and delete tracked custom-port entries atomically.
@@ -1861,6 +1912,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping
if d := opts.GetRequestTimeout(); d != nil {
pt.RequestTimeout = d.AsDuration()
}
pt.DirectUpstream = opts.GetDirectUpstream()
}
pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout)
paths[pathMapping.GetPath()] = pt

View File

@@ -1,525 +0,0 @@
package proxy
import (
"context"
"errors"
"fmt"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestIntegration_SyncMappings_HappyPath(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.SyncMappings(ctx)
require.NoError(t, err)
// Send init.
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: "sync-proxy-1",
Version: "test-v1",
Address: "test.proxy.io",
},
},
})
require.NoError(t, err)
mappingsByID := make(map[string]*proto.ProxyMapping)
for {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
// Ack every batch.
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
assert.Len(t, mappingsByID, 2, "Should receive 2 mappings")
rp1 := mappingsByID["rp-1"]
require.NotNil(t, rp1)
assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain())
assert.Equal(t, "test-account-1", rp1.GetAccountId())
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType())
assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token")
rp2 := mappingsByID["rp-2"]
require.NotNil(t, rp2)
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
}
func TestIntegration_SyncMappings_BackPressure(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
// Add enough services to guarantee multiple batches (default batch size 500).
addServicesToStore(t, setup, 600, "test.proxy.io")
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
stream, err := client.SyncMappings(ctx)
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: "sync-proxy-backpressure",
Version: "test-v1",
Address: "test.proxy.io",
},
},
})
require.NoError(t, err)
// Strategy: receive batch 1, then hold for a significant delay before
// acking. If back-pressure works, batch 2 cannot arrive until after
// the ack is sent — so its receive timestamp must be >= the ack
// timestamp. If management were fire-and-forget, all batches would
// already be buffered in the gRPC transport and batch 2 would arrive
// well before the ack time.
const ackDelay = 300 * time.Millisecond
type batchEvent struct {
recvAt time.Time
ackAt time.Time
count int
}
var batches []batchEvent
var totalMappings int
for {
msg, err := stream.Recv()
require.NoError(t, err)
recvAt := time.Now()
totalMappings += len(msg.GetMapping())
// Delay the ack on non-final batches to create a measurable gap.
if !msg.GetInitialSyncComplete() {
time.Sleep(ackDelay)
}
ackAt := time.Now()
batches = append(batches, batchEvent{
recvAt: recvAt,
ackAt: ackAt,
count: len(msg.GetMapping()),
})
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
// 2 original + 600 added = 602 services total.
assert.Equal(t, 602, totalMappings, "should receive all 602 mappings")
require.GreaterOrEqual(t, len(batches), 2, "need at least 2 batches to verify back-pressure")
// For every batch after the first, its receive time must be after the
// previous batch's ack time. This proves management waited for the ack
// before sending the next batch.
for i := 1; i < len(batches); i++ {
prevAckAt := batches[i-1].ackAt
thisRecvAt := batches[i].recvAt
assert.True(t, !thisRecvAt.Before(prevAckAt),
"batch %d received at %v, but batch %d was acked at %v — "+
"management sent the next batch before receiving the ack",
i, thisRecvAt, i-1, prevAckAt)
}
}
func TestIntegration_SyncMappings_IncrementalUpdate(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.SyncMappings(ctx)
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: "sync-proxy-incremental",
Version: "test-v1",
Address: "test.proxy.io",
},
},
})
require.NoError(t, err)
// Drain initial snapshot.
for {
msg, err := stream.Recv()
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
// Now send an incremental update via the management server.
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "rp-1",
AccountId: "test-account-1",
Domain: "app1.test.proxy.io",
}},
})
// Receive the incremental update on the sync stream.
msg, err := stream.Recv()
require.NoError(t, err)
require.NotEmpty(t, msg.GetMapping())
assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId())
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.GetMapping()[0].GetType())
}
func TestIntegration_SyncMappings_MixedProxyVersions(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Old proxy uses GetMappingUpdate.
legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "legacy-proxy",
Version: "old-v1",
Address: "test.proxy.io",
})
require.NoError(t, err)
var legacyMappings []*proto.ProxyMapping
for {
msg, err := legacyStream.Recv()
require.NoError(t, err)
legacyMappings = append(legacyMappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
// New proxy uses SyncMappings.
syncStream, err := client.SyncMappings(ctx)
require.NoError(t, err)
err = syncStream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: "new-proxy",
Version: "new-v2",
Address: "test.proxy.io",
},
},
})
require.NoError(t, err)
var syncMappings []*proto.ProxyMapping
for {
msg, err := syncStream.Recv()
require.NoError(t, err)
syncMappings = append(syncMappings, msg.GetMapping()...)
err = syncStream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
// Both should receive the same set of mappings.
assert.Equal(t, len(legacyMappings), len(syncMappings),
"legacy and sync proxies should receive the same number of mappings")
legacyIDs := make(map[string]bool)
for _, m := range legacyMappings {
legacyIDs[m.GetId()] = true
}
for _, m := range syncMappings {
assert.True(t, legacyIDs[m.GetId()],
"mapping %s should be present in both streams", m.GetId())
}
// Both proxies should be connected.
proxies := setup.proxyService.GetConnectedProxies()
assert.Contains(t, proxies, "legacy-proxy")
assert.Contains(t, proxies, "new-proxy")
// Both should receive incremental updates.
setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "rp-1",
AccountId: "test-account-1",
Domain: "app1.test.proxy.io",
}},
})
// Legacy proxy receives via GetMappingUpdateResponse.
legacyMsg, err := legacyStream.Recv()
require.NoError(t, err)
assert.Equal(t, "rp-1", legacyMsg.GetMapping()[0].GetId())
// Sync proxy receives via SyncMappingsResponse.
syncMsg, err := syncStream.Recv()
require.NoError(t, err)
assert.Equal(t, "rp-1", syncMsg.GetMapping()[0].GetId())
}
func TestIntegration_SyncMappings_Reconnect(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
proxyID := "sync-proxy-reconnect"
receiveMappings := func() []*proto.ProxyMapping {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.SyncMappings(ctx)
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: proxyID,
Version: "test-v1",
Address: "test.proxy.io",
},
},
})
require.NoError(t, err)
var mappings []*proto.ProxyMapping
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}},
})
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
first := receiveMappings()
time.Sleep(100 * time.Millisecond)
second := receiveMappings()
assert.Equal(t, len(first), len(second),
"should receive same mappings on reconnect")
firstIDs := make(map[string]bool)
for _, m := range first {
firstIDs[m.GetId()] = true
}
for _, m := range second {
assert.True(t, firstIDs[m.GetId()],
"mapping %s should be present in both connections", m.GetId())
}
}
// --- Fallback tests: old management returns Unimplemented ---
// unimplementedProxyServer embeds UnimplementedProxyServiceServer so
// SyncMappings returns codes.Unimplemented while GetMappingUpdate works.
type unimplementedSyncServer struct {
proto.UnimplementedProxyServiceServer
getMappingCalls atomic.Int32
}
func (s *unimplementedSyncServer) GetMappingUpdate(_ *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
s.getMappingCalls.Add(1)
return stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}},
InitialSyncComplete: true,
})
}
func TestIntegration_FallbackToGetMappingUpdate(t *testing.T) {
// Start a gRPC server that does NOT implement SyncMappings.
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
srv := &unimplementedSyncServer{}
grpcServer := grpc.NewServer()
proto.RegisterProxyServiceServer(grpcServer, srv)
go func() { _ = grpcServer.Serve(lis) }()
defer grpcServer.GracefulStop()
conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
// Try SyncMappings — should get Unimplemented.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
stream, err := client.SyncMappings(ctx)
require.NoError(t, err)
err = stream.Send(&proto.SyncMappingsRequest{
Msg: &proto.SyncMappingsRequest_Init{
Init: &proto.SyncMappingsInit{
ProxyId: "test-proxy",
Address: "test.example.com",
},
},
})
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unimplemented, st.Code(),
"unimplemented SyncMappings should return Unimplemented code")
// isSyncUnimplemented should detect this.
assert.True(t, isSyncUnimplemented(err))
// The actual fallback: GetMappingUpdate should work.
legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "test-proxy",
Address: "test.example.com",
})
require.NoError(t, err)
msg, err := legacyStream.Recv()
require.NoError(t, err)
assert.True(t, msg.GetInitialSyncComplete())
assert.Len(t, msg.GetMapping(), 1)
assert.Equal(t, int32(1), srv.getMappingCalls.Load())
}
func TestIsSyncUnimplemented(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil error", nil, false},
{"non-grpc error", errors.New("random"), false},
{"grpc internal", grpcstatus.Error(codes.Internal, "fail"), false},
{"grpc unavailable", grpcstatus.Error(codes.Unavailable, "fail"), false},
{"grpc unimplemented", grpcstatus.Error(codes.Unimplemented, "method not found"), true},
{
"wrapped unimplemented",
fmt.Errorf("create sync stream: %w", grpcstatus.Error(codes.Unimplemented, "nope")),
// grpc/status.FromError unwraps in recent versions of grpc-go.
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, isSyncUnimplemented(tt.err))
})
}
}
// addServicesToStore adds n extra services to the test store for the given cluster.
func addServicesToStore(t *testing.T, setup *integrationTestSetup, n int, cluster string) {
t.Helper()
ctx := context.Background()
for i := 0; i < n; i++ {
svc := &service.Service{
ID: fmt.Sprintf("extra-svc-%d", i),
AccountID: "test-account-1",
Name: fmt.Sprintf("Extra Service %d", i),
Domain: fmt.Sprintf("extra-%d.test.proxy.io", i),
ProxyCluster: cluster,
Enabled: true,
Targets: []*service.Target{{
Path: strPtr("/"),
Host: fmt.Sprintf("10.0.1.%d", i%256),
Port: 8080,
Protocol: "http",
TargetId: fmt.Sprintf("peer-extra-%d", i),
TargetType: "peer",
Enabled: true,
}},
}
require.NoError(t, setup.store.CreateService(ctx, svc))
}
}