mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 21:19:55 +00:00
[management, client, proxy] add expose NetBird-only services over tunnel peers (#6226)
Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes. Wire contract - ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed. - ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy. - ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups. - ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip. - ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards. - PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard. Data model - Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns. - DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure. - Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0. Proxy - /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go. - Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack. - proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction. - Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC. - proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level). Identity forwarding - ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request. - CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing). - AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships. OpenAPI/dashboard surface - ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups.
This commit is contained in:
47
proxy/internal/auth/identity.go
Normal file
47
proxy/internal/auth/identity.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// PeerIdentity describes the locally-known facts about a peer reachable on
|
||||
// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP
|
||||
// and FQDN from the embedded client's peerstore. UserID, Email and Groups
|
||||
// stay zero in V1 — full identity still travels through ValidateTunnelPeer.
|
||||
// Phase V2 will populate them once RemotePeerConfig carries user identity.
|
||||
type PeerIdentity struct {
|
||||
PubKey string
|
||||
TunnelIP netip.Addr
|
||||
FQDN string
|
||||
|
||||
// V2 fields (zero in V1).
|
||||
UserID string
|
||||
Email string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally
|
||||
// available peerstore data. ok=false means the IP is not in the calling
|
||||
// account's roster.
|
||||
type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool)
|
||||
|
||||
type tunnelLookupContextKey struct{}
|
||||
|
||||
// WithTunnelLookup attaches a per-account peerstore lookup function to
|
||||
// the request context. The auth middleware calls this lookup before
|
||||
// hitting management's ValidateTunnelPeer to short-circuit unknown IPs
|
||||
// and to skip the RPC for already-cached identities.
|
||||
func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context {
|
||||
if lookup == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, tunnelLookupContextKey{}, lookup)
|
||||
}
|
||||
|
||||
// TunnelLookupFromContext returns the peerstore lookup attached to ctx,
|
||||
// or nil when the request did not arrive on a per-account listener.
|
||||
func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc {
|
||||
v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc)
|
||||
return v
|
||||
}
|
||||
@@ -36,6 +36,7 @@ type authenticator interface {
|
||||
// SessionValidator validates session tokens and checks user access permissions.
|
||||
type SessionValidator interface {
|
||||
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
|
||||
ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error)
|
||||
}
|
||||
|
||||
// Scheme defines an authentication mechanism for a domain.
|
||||
@@ -56,12 +57,21 @@ type DomainConfig struct {
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
IPRestrictions *restrict.Filter
|
||||
// Private routes the domain through ValidateTunnelPeer; failure → 403.
|
||||
Private bool
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
UserID string
|
||||
UserEmail string
|
||||
Valid bool
|
||||
DeniedReason string
|
||||
Groups []string
|
||||
// GroupNames carries the human-readable display names for Groups,
|
||||
// ordered identically (positional pairing). May be shorter than
|
||||
// Groups for tokens minted before names were embedded; the consumer
|
||||
// falls back to ids for missing positions.
|
||||
GroupNames []string
|
||||
}
|
||||
|
||||
// Middleware applies per-domain authentication and IP restriction checks.
|
||||
@@ -71,6 +81,7 @@ type Middleware struct {
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
geo restrict.GeoResolver
|
||||
tunnelCache *tunnelValidationCache
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware. The sessionValidator is
|
||||
@@ -84,6 +95,7 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
geo: geo,
|
||||
tunnelCache: newTunnelValidationCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +123,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Private services bypass operator schemes and gate on tunnel peer.
|
||||
if config.Private {
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Domains with no authentication schemes pass through after IP checks.
|
||||
if len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -129,10 +150,54 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithTunnelPeer(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.blockOIDCOnPlainHTTP(w, r, config) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
// requestIsPlainHTTP reports whether the request arrived without TLS.
|
||||
// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block.
|
||||
func requestIsPlainHTTP(r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
|
||||
// hasOIDCScheme reports whether any of the configured schemes requires
|
||||
// TLS to round-trip safely with an external IdP.
|
||||
func hasOIDCScheme(schemes []Scheme) bool {
|
||||
for _, s := range schemes {
|
||||
if s.Type() == auth.MethodOIDC {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit
|
||||
// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing
|
||||
// the misconfiguration here yields a clearer error than the IdP's
|
||||
// "invalid redirect_uri" round-trip.
|
||||
func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
|
||||
if !requestIsPlainHTTP(r) {
|
||||
return false
|
||||
}
|
||||
if !hasOIDCScheme(config.Schemes) {
|
||||
return false
|
||||
}
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": r.Host,
|
||||
"remote": r.RemoteAddr,
|
||||
}).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 — use port 443")
|
||||
http.Error(w, "OIDC requires TLS — use port 443", http.StatusBadRequest)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
@@ -162,7 +227,17 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request
|
||||
return false
|
||||
}
|
||||
|
||||
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
var verdict restrict.Verdict
|
||||
if types.IsOverlayOrigin(r.Context()) {
|
||||
// Geo/CrowdSec checks don't apply over the WireGuard overlay:
|
||||
// the source address is always inside the NetBird CGNAT range,
|
||||
// which is never in a GeoIP database or a CrowdSec decision
|
||||
// list. Enforcing them here would either no-op (best case) or
|
||||
// fail-closed when the geo database is missing.
|
||||
verdict = config.IPRestrictions.CheckCIDR(clientIP)
|
||||
} else {
|
||||
verdict = config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
}
|
||||
if verdict == restrict.Allow {
|
||||
return true
|
||||
}
|
||||
@@ -246,18 +321,111 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(email)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the
|
||||
// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy
|
||||
// asks management to resolve it to a peer/user and to gate by the service's
|
||||
// distribution_groups. On success the proxy installs the freshly minted JWT
|
||||
// as a session cookie, sets UserID + Method=oidc on the captured data, and
|
||||
// forwards directly — operators see the same access-log shape as if the user
|
||||
// had completed an OIDC redirect. Any failure (private-range mismatch,
|
||||
// management unreachable, peer unknown, user not in group) returns false so
|
||||
// the caller falls back to the existing OIDC scheme dispatch.
|
||||
//
|
||||
// Phase 3 adds a local-first short-circuit: when the request arrived on a
|
||||
// per-account inbound listener the context carries a peerstore lookup
|
||||
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
|
||||
// roster the proxy denies fast without calling management. If the lookup
|
||||
// confirms a known peer the RPC still runs for the user-identity tail
|
||||
// (UserID + group access), but its result is cached for tunnelCacheTTL so
|
||||
// repeat requests skip management entirely.
|
||||
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
if mw.sessionValidator == nil {
|
||||
return false
|
||||
}
|
||||
clientIP := mw.resolveClientIP(r)
|
||||
if !clientIP.IsValid() {
|
||||
return false
|
||||
}
|
||||
if !isTunnelSourceIP(clientIP) {
|
||||
return false
|
||||
}
|
||||
|
||||
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
|
||||
if _, ok := lookup(clientIP); !ok {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"host": host,
|
||||
"remote": clientIP,
|
||||
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{
|
||||
accountID: config.AccountID,
|
||||
tunnelIP: clientIP,
|
||||
domain: host,
|
||||
}, mw.validateTunnelPeer)
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC")
|
||||
return false
|
||||
}
|
||||
if !resp.GetValid() || resp.GetSessionToken() == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(resp.GetUserId())
|
||||
cd.SetUserEmail(resp.GetUserEmail())
|
||||
cd.SetUserGroups(resp.GetPeerGroupIds())
|
||||
cd.SetUserGroupNames(resp.GetPeerGroupNames())
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// validateTunnelPeer adapts the SessionValidator interface to the cache's
|
||||
// validateTunnelPeerFn signature.
|
||||
func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
return mw.sessionValidator.ValidateTunnelPeer(ctx, req)
|
||||
}
|
||||
|
||||
// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird
|
||||
// allocates tunnel addresses from by default. IsPrivate() doesn't include
|
||||
// it, so we check it explicitly.
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
|
||||
// isTunnelSourceIP reports whether ip falls within an address range typical
|
||||
// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10
|
||||
// (NetBird's default range). Loopback and link-local are excluded — the
|
||||
// fast-path is meant for peer-to-peer mesh traffic, not localhost.
|
||||
func isTunnelSourceIP(ip netip.Addr) bool {
|
||||
if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() {
|
||||
return false
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
return cgnatPrefix.Contains(ip)
|
||||
}
|
||||
|
||||
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
|
||||
// the request is forwarded directly (no redirect), which is important for API clients.
|
||||
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
@@ -286,7 +454,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
|
||||
if err != nil {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
@@ -298,7 +466,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
setHeaderCapturedData(r.Context(), result.UserID)
|
||||
setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -306,6 +474,9 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
}
|
||||
|
||||
@@ -315,7 +486,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho
|
||||
|
||||
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
|
||||
if errors.Is(err, ErrHeaderAuthFailed) {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
setHeaderCapturedData(r.Context(), "", "", nil, nil)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
@@ -327,7 +498,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque
|
||||
return true
|
||||
}
|
||||
|
||||
func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) {
|
||||
cd := proxy.CapturedDataFromContext(ctx)
|
||||
if cd == nil {
|
||||
return
|
||||
@@ -335,6 +506,9 @@ func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
cd.SetUserID(userID)
|
||||
cd.SetUserEmail(userEmail)
|
||||
cd.SetUserGroups(groups)
|
||||
cd.SetUserGroupNames(groupNames)
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
@@ -405,6 +579,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
@@ -419,6 +596,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetUserEmail(result.UserEmail)
|
||||
cd.SetUserGroups(result.Groups)
|
||||
cd.SetUserGroupNames(result.GroupNames)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
@@ -454,12 +634,9 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AddDomain registers authentication schemes for the given domain.
|
||||
// If schemes are provided, a valid session public key is required to sign/verify
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
|
||||
// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required.
|
||||
// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
@@ -467,6 +644,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -488,6 +666,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
Private: private,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -518,18 +697,25 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
}).Debug("Session validation denied")
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: false,
|
||||
DeniedReason: resp.DeniedReason,
|
||||
}, nil
|
||||
}
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
UserEmail: resp.GetUserEmail(),
|
||||
Valid: true,
|
||||
Groups: resp.GetPeerGroupIds(),
|
||||
GroupNames: resp.GetPeerGroupNames(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &validationResult{UserID: userID, Valid: true}, nil
|
||||
return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns the request URI with the session_token query
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -62,7 +64,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -79,7 +81,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -93,7 +95,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -108,7 +110,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -121,7 +123,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -137,8 +139,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -154,7 +156,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -178,7 +180,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -195,7 +197,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -216,7 +218,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -237,9 +239,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -262,15 +264,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
assert.Equal(t, "authenticated", rec.Body.String())
|
||||
}
|
||||
|
||||
// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the
|
||||
// JWT's groups claim into CapturedData so policy-aware middlewares can
|
||||
// authorise without an extra management round-trip.
|
||||
func TestProtect_SessionCookieGroupsPropagate(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
groups := []string{"engineering", "sre"}
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd, "captured data must be present in request context")
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed")
|
||||
assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes")
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -293,10 +328,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -320,10 +355,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
@@ -345,7 +380,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
scheme := &stubScheme{
|
||||
@@ -357,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -410,7 +445,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -427,7 +462,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First scheme (PIN) always fails, second scheme (password) succeeds.
|
||||
@@ -446,7 +481,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -476,7 +511,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -500,7 +535,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil)
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false)
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
@@ -509,10 +544,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false)
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
@@ -536,7 +571,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -563,7 +598,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -590,7 +625,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -678,7 +713,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -714,7 +749,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -755,7 +790,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}))
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -770,6 +805,69 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) {
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny")
|
||||
}
|
||||
|
||||
// TestCheckIPRestrictions_OverlayOriginSkipsCountryRules covers the
|
||||
// inbound (WG) listener path: requests stamped with WithOverlayOrigin
|
||||
// must skip country lookups, even when no geo database is configured.
|
||||
// Without this short-circuit the inbound flow would fail-closed for
|
||||
// every overlay request whenever country rules are configured.
|
||||
func TestCheckIPRestrictions_OverlayOriginSkipsCountryRules(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{
|
||||
AllowedCIDRs: []string{"100.64.0.0/10"},
|
||||
AllowedCountries: []string{"US"},
|
||||
}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.64.5.6:5000"
|
||||
req.Host = "example.com"
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code,
|
||||
"overlay-origin requests must not be denied by country rules they would fail without geo data")
|
||||
|
||||
// Sanity check: the same filter without the overlay flag denies (no geo,
|
||||
// country allowlist active → DenyGeoUnavailable).
|
||||
req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req2.RemoteAddr = "100.64.5.6:5000"
|
||||
req2.Host = "example.com"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusForbidden, rr2.Code,
|
||||
"WAN-origin requests must still hit the full Check path and be denied without geo data")
|
||||
}
|
||||
|
||||
// TestCheckIPRestrictions_OverlayOriginRespectsCIDR confirms CIDR
|
||||
// rules still apply on the overlay path so operators retain a way to
|
||||
// scope private services to specific peer subnets.
|
||||
func TestCheckIPRestrictions_OverlayOriginRespectsCIDR(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1",
|
||||
restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"100.64.0.0/16"}}), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.RemoteAddr = "100.65.5.6:5000" // outside 100.64.0.0/16
|
||||
req.Host = "example.com"
|
||||
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code,
|
||||
"CIDR rules must still apply on the overlay path")
|
||||
}
|
||||
|
||||
func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
@@ -781,11 +879,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) {
|
||||
return "", oidcURL, nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -809,11 +908,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
@@ -834,7 +934,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti
|
||||
// returns a signed session token when the expected header value is provided.
|
||||
func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
@@ -852,7 +952,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
@@ -895,7 +995,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) {
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
// Also add a PIN scheme so we can verify fallthrough behavior.
|
||||
pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -915,7 +1015,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) {
|
||||
return &proto.AuthenticateResponse{Success: false}, nil
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
capturedData := proxy.NewCapturedData("")
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
@@ -938,7 +1038,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) {
|
||||
return nil, errors.New("gRPC unavailable")
|
||||
}}
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -955,7 +1055,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -1006,7 +1106,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
ha := req.GetHeaderAuth()
|
||||
if ha != nil && accepted[ha.GetHeaderValue()] {
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour)
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil
|
||||
}
|
||||
@@ -1015,7 +1115,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
|
||||
// Single Header scheme (as if one entry existed), but the mock checks both values.
|
||||
hdr := NewHeader(mock, "svc1", "acc1", "Authorization")
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false))
|
||||
|
||||
var backendCalled bool
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -1059,3 +1159,71 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) {
|
||||
assert.False(t, backendCalled, "unknown token should be rejected")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC
|
||||
// scheme is configured and the request arrived without TLS, the middleware
|
||||
// short-circuits with a 400 instead of dispatching to the IdP redirect.
|
||||
func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected")
|
||||
assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection")
|
||||
}
|
||||
|
||||
// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works
|
||||
// over TLS — the block only fires on plain HTTP.
|
||||
func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodOIDC,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "https://idp.example.com/authorize", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP")
|
||||
}
|
||||
|
||||
// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC
|
||||
// block only fires when an OIDC scheme is configured. PIN-only domains
|
||||
// pass through normally on plain HTTP.
|
||||
func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
|
||||
}
|
||||
|
||||
171
proxy/internal/auth/tunnel_cache.go
Normal file
171
proxy/internal/auth/tunnel_cache.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is
|
||||
// reused before re-fetching from management. 5 minutes balances freshness
|
||||
// against management load on busy mesh networks.
|
||||
const tunnelCacheTTL = 300 * time.Second
|
||||
|
||||
// tunnelCachePerAccount caps the number of cached identities per account.
|
||||
// Bounded eviction avoids memory growth in pathological cases (huge peer
|
||||
// roster, brief request bursts) while staying generous for normal use.
|
||||
const tunnelCachePerAccount = 1024
|
||||
|
||||
// tunnelCacheKey identifies a cached entry by tunnel IP and originating
|
||||
// account. Domain is part of the value, not the key, because the
|
||||
// management response is per (account, IP) — domain only gates whether a
|
||||
// re-fetch is needed if the operator is accessing a different service.
|
||||
type tunnelCacheKey struct {
|
||||
accountID types.AccountID
|
||||
tunnelIP netip.Addr
|
||||
domain string
|
||||
}
|
||||
|
||||
// tunnelCacheEntry stores a positive validation response with the time it
|
||||
// was minted. Entries past tunnelCacheTTL are treated as misses.
|
||||
type tunnelCacheEntry struct {
|
||||
resp *proto.ValidateTunnelPeerResponse
|
||||
cachedAt time.Time
|
||||
}
|
||||
|
||||
// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by
|
||||
// (accountID, tunnelIP, domain). Only successful, valid responses are
|
||||
// cached — denials skip the cache so policy changes apply immediately.
|
||||
// Single-flight de-duplicates concurrent fetches for the same key so a
|
||||
// burst of cold requests collapses into a single RPC.
|
||||
type tunnelValidationCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[types.AccountID]*accountBucket
|
||||
flight singleflight.Group
|
||||
ttl time.Duration
|
||||
maxSize int
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// accountBucket holds the cached entries for a single account, with a
|
||||
// FIFO eviction queue used when the bucket exceeds maxSize.
|
||||
type accountBucket struct {
|
||||
items map[tunnelCacheKey]tunnelCacheEntry
|
||||
order []tunnelCacheKey
|
||||
}
|
||||
|
||||
// newTunnelValidationCache constructs a cache with default TTL and bounds.
|
||||
func newTunnelValidationCache() *tunnelValidationCache {
|
||||
return &tunnelValidationCache{
|
||||
entries: make(map[types.AccountID]*accountBucket),
|
||||
ttl: tunnelCacheTTL,
|
||||
maxSize: tunnelCachePerAccount,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// get returns a cached response for the key, or nil when missing or
|
||||
// expired. Expired entries are evicted lazily on read.
|
||||
func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
entry, ok := bucket.items[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if c.now().Sub(entry.cachedAt) > c.ttl {
|
||||
delete(bucket.items, key)
|
||||
bucket.order = removeKey(bucket.order, key)
|
||||
return nil
|
||||
}
|
||||
return entry.resp
|
||||
}
|
||||
|
||||
// put records a positive response under the key. Evicts the oldest entry
|
||||
// in the account's bucket when the bound is exceeded.
|
||||
func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
bucket, ok := c.entries[key.accountID]
|
||||
if !ok {
|
||||
bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)}
|
||||
c.entries[key.accountID] = bucket
|
||||
}
|
||||
if _, exists := bucket.items[key]; !exists {
|
||||
bucket.order = append(bucket.order, key)
|
||||
}
|
||||
bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()}
|
||||
|
||||
for len(bucket.order) > c.maxSize {
|
||||
oldest := bucket.order[0]
|
||||
bucket.order = bucket.order[1:]
|
||||
delete(bucket.items, oldest)
|
||||
}
|
||||
}
|
||||
|
||||
// removeKey drops the first occurrence of needle from order. The cache
|
||||
// uses small slices so a linear scan is cheaper than a map+slice combo.
|
||||
func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey {
|
||||
for i, k := range order {
|
||||
if k == needle {
|
||||
return append(order[:i], order[i+1:]...)
|
||||
}
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
// flightKey turns a cache key into a single-flight string. AccountID and
|
||||
// IP isolation by themselves are insufficient because different domains
|
||||
// for the same peer/account may have different group access.
|
||||
func flightKey(key tunnelCacheKey) string {
|
||||
return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain
|
||||
}
|
||||
|
||||
// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches
|
||||
// the SessionValidator.ValidateTunnelPeer signature without exposing the
|
||||
// gRPC option variadic, since callers don't need it on the cache hot path.
|
||||
type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error)
|
||||
|
||||
// fetch returns a cached response when present, otherwise calls validate
|
||||
// under single-flight and caches the result. Denied responses pass
|
||||
// through but are not cached so policy changes apply immediately.
|
||||
func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) {
|
||||
if resp := c.get(key); resp != nil {
|
||||
return resp, true, nil
|
||||
}
|
||||
|
||||
flight := flightKey(key)
|
||||
res, err, _ := c.flight.Do(flight, func() (any, error) {
|
||||
if cached := c.get(key); cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{
|
||||
TunnelIp: key.tunnelIP.String(),
|
||||
Domain: key.domain,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.GetValid() && resp.GetSessionToken() != "" {
|
||||
c.put(key, resp)
|
||||
}
|
||||
return resp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
resp, _ := res.(*proto.ValidateTunnelPeerResponse)
|
||||
return resp, false, nil
|
||||
}
|
||||
171
proxy/internal/auth/tunnel_cache_test.go
Normal file
171
proxy/internal/auth/tunnel_cache_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey {
|
||||
return tunnelCacheKey{
|
||||
accountID: account,
|
||||
tunnelIP: netip.MustParseAddr(ip),
|
||||
domain: domain,
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnelCache_HitSkipsRPC(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
var calls int32
|
||||
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil
|
||||
}
|
||||
|
||||
resp, fromCache, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp, "first fetch returns RPC response")
|
||||
assert.False(t, fromCache, "first fetch must not be cached")
|
||||
|
||||
resp2, fromCache2, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp2, "second fetch returns cached response")
|
||||
assert.True(t, fromCache2, "second fetch must be served from cache")
|
||||
assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit")
|
||||
}
|
||||
|
||||
func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
clock := time.Now()
|
||||
cache.now = func() time.Time { return clock }
|
||||
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
|
||||
}
|
||||
|
||||
_, _, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "first fetch issues one RPC")
|
||||
|
||||
clock = clock.Add(tunnelCacheTTL + time.Second)
|
||||
|
||||
_, fromCache, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, fromCache, "expired entry must miss the cache")
|
||||
assert.Equal(t, int32(2), atomic.LoadInt32(&calls), "expired entry forces a re-fetch")
|
||||
}
|
||||
|
||||
func TestTunnelCache_DeniedResponseNotCached(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := cache.fetch(context.Background(), key, validate)
|
||||
require.NoError(t, err, "fetch must not error on denied response")
|
||||
}
|
||||
assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately")
|
||||
}
|
||||
|
||||
func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
key := newTestKey("acct-1", "100.64.0.10", "svc.example")
|
||||
|
||||
gate := make(chan struct{})
|
||||
var calls int32
|
||||
validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
<-gate
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil
|
||||
}
|
||||
|
||||
const workers = 16
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers)
|
||||
results := make([]bool, workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
resp, _, err := cache.fetch(context.Background(), key, validate)
|
||||
results[idx] = err == nil && resp.GetValid()
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
close(gate)
|
||||
wg.Wait()
|
||||
|
||||
for i, ok := range results {
|
||||
assert.Truef(t, ok, "worker %d should observe a successful response", i)
|
||||
}
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "single-flight must collapse concurrent cold fetches into one RPC")
|
||||
}
|
||||
|
||||
func TestTunnelCache_PerAccountIsolation(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
keyA := newTestKey("acct-a", "100.64.0.10", "svc.example")
|
||||
keyB := newTestKey("acct-b", "100.64.0.10", "svc.example")
|
||||
|
||||
var callsA, callsB int32
|
||||
validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&callsA, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil
|
||||
}
|
||||
validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
atomic.AddInt32(&callsB, 1)
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil
|
||||
}
|
||||
|
||||
respA, _, err := cache.fetch(context.Background(), keyA, validateA)
|
||||
require.NoError(t, err)
|
||||
respB, _, err := cache.fetch(context.Background(), keyB, validateB)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a")
|
||||
assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once")
|
||||
}
|
||||
|
||||
func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) {
|
||||
cache := newTunnelValidationCache()
|
||||
cache.maxSize = 2
|
||||
|
||||
validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil
|
||||
}
|
||||
|
||||
keys := []tunnelCacheKey{
|
||||
newTestKey("acct-1", "100.64.0.10", "svc"),
|
||||
newTestKey("acct-1", "100.64.0.11", "svc"),
|
||||
newTestKey("acct-1", "100.64.0.12", "svc"),
|
||||
}
|
||||
for _, k := range keys {
|
||||
_, _, err := cache.fetch(context.Background(), k, validate)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Nil(t, cache.get(keys[0]), "oldest key should be evicted past maxSize")
|
||||
assert.NotNil(t, cache.get(keys[1]), "second-newest must remain cached")
|
||||
assert.NotNil(t, cache.get(keys[2]), "newest must remain cached")
|
||||
}
|
||||
325
proxy/internal/auth/tunnel_lookup_test.go
Normal file
325
proxy/internal/auth/tunnel_lookup_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// stubSessionValidator records ValidateTunnelPeer calls and returns the
|
||||
// pre-canned response. Counts let tests assert RPC traffic.
|
||||
type stubSessionValidator struct {
|
||||
respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse
|
||||
respErr error
|
||||
tunnelCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
|
||||
return &proto.ValidateSessionResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
|
||||
s.tunnelCalls.Add(1)
|
||||
if s.respErr != nil {
|
||||
return nil, s.respErr
|
||||
}
|
||||
if s.respFn != nil {
|
||||
return s.respFn(in), nil
|
||||
}
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware {
|
||||
t.Helper()
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
require.NoError(t, mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false))
|
||||
return mw
|
||||
}
|
||||
|
||||
func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil)
|
||||
r.Host = "svc.example"
|
||||
r.RemoteAddr = remoteAddr
|
||||
return w, r
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast verifies the
|
||||
// short-circuit: a tunnel IP not in the account's roster never reaches
|
||||
// management's ValidateTunnelPeer.
|
||||
func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) {
|
||||
validator := &stubSessionValidator{}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.99:55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.False(t, handled, "unknown peer must fall through, not forward")
|
||||
assert.False(t, called, "next handler must not run for unknown peer")
|
||||
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData verifies the proxy
|
||||
// surfaces the calling peer's group memberships from ValidateTunnelPeerResponse
|
||||
// onto CapturedData so policy-aware middlewares can authorise without an
|
||||
// extra management round-trip.
|
||||
func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
|
||||
groups := []string{"engineering", "sre"}
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "tok",
|
||||
UserId: "user-1",
|
||||
PeerGroupIds: groups,
|
||||
}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
cd := proxy.NewCapturedData("")
|
||||
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
require.True(t, handled, "valid tunnel-peer response must forward")
|
||||
require.True(t, called, "next handler must run")
|
||||
assert.Equal(t, "user-1", cd.GetUserID(), "user id must propagate from tunnel-peer response")
|
||||
assert.Equal(t, groups, cd.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs verifies that a
|
||||
// known tunnel IP still triggers ValidateTunnelPeer for the user-identity
|
||||
// tail (UserID + group access). Phase 3 only short-circuits the deny path.
|
||||
func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
if ip == knownIP {
|
||||
return PeerIdentity{PubKey: "pk", TunnelIP: ip, FQDN: "peer.netbird.cloud"}, true
|
||||
}
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.True(t, handled, "known peer with valid RPC response must forward")
|
||||
assert.True(t, called, "next handler must run on success")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
|
||||
// behaviour stays intact on the host-level listener (no lookup attached).
|
||||
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
called := false
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
|
||||
assert.True(t, handled, "host-level path forwards on positive RPC result")
|
||||
assert.True(t, called, "next handler runs on host-level success")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
|
||||
// failure still falls through to the next scheme (no false positive).
|
||||
func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) {
|
||||
validator := &stubSessionValidator{respErr: errors.New("management down")}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
return PeerIdentity{TunnelIP: ip}, true
|
||||
})
|
||||
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
|
||||
assert.False(t, handled, "RPC error must let the caller try other schemes")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC was attempted exactly once")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_CacheReusesPositiveResponse confirms the
|
||||
// (account, IP, domain) cache prevents repeated RPCs for the same peer.
|
||||
func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
w, r := newTunnelRequest("100.64.0.10:55555")
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
|
||||
require.True(t, handled, "iteration %d should forward", i)
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey ensures cache keys
|
||||
// honour account scoping — same tunnel IP on different accounts must not
|
||||
// collide.
|
||||
func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"}
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
|
||||
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
|
||||
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
|
||||
|
||||
for _, host := range []string{"svc-a.example", "svc-b.example"} {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
|
||||
r.Host = host
|
||||
r.RemoteAddr = "100.64.0.10:55555"
|
||||
config, _ := mw.getDomainConfig(host)
|
||||
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
require.True(t, handled, "host %s should forward", host)
|
||||
}
|
||||
|
||||
assert.Equal(t, int32(2), validator.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match")
|
||||
}
|
||||
|
||||
// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache
|
||||
// guarantees that the deny-fast path leaves the cache untouched, so a
|
||||
// subsequent request from the same IP after the peerstore catches up
|
||||
// goes through the normal RPC flow.
|
||||
func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}
|
||||
},
|
||||
}
|
||||
mw := newTunnelMiddleware(t, validator)
|
||||
|
||||
knownIP := netip.MustParseAddr("100.64.0.10")
|
||||
known := false
|
||||
lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) {
|
||||
if known && ip == knownIP {
|
||||
return PeerIdentity{TunnelIP: ip}, true
|
||||
}
|
||||
return PeerIdentity{}, false
|
||||
})
|
||||
|
||||
doRequest := func() bool {
|
||||
w, r := newTunnelRequest(knownIP.String() + ":55555")
|
||||
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
|
||||
config, _ := mw.getDomainConfig("svc.example")
|
||||
return mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
}
|
||||
|
||||
require.False(t, doRequest(), "first request must short-circuit")
|
||||
require.Equal(t, int32(0), validator.tunnelCalls.Load(), "short-circuit must not populate the cache")
|
||||
|
||||
known = true
|
||||
require.True(t, doRequest(), "second request with peer in roster must forward via RPC")
|
||||
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC runs once after peerstore catches up")
|
||||
}
|
||||
|
||||
func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) {
|
||||
mw := NewMiddleware(log.New(), nil, nil)
|
||||
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
|
||||
|
||||
called := false
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.10:55555"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
|
||||
validator := &stubSessionValidator{
|
||||
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
|
||||
return &proto.ValidateTunnelPeerResponse{
|
||||
Valid: true,
|
||||
SessionToken: "tok",
|
||||
UserId: "user-1",
|
||||
}
|
||||
},
|
||||
}
|
||||
mw := NewMiddleware(log.New(), validator, nil)
|
||||
require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true))
|
||||
|
||||
called := false
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
|
||||
req.Host = "private.svc"
|
||||
req.RemoteAddr = "100.64.0.10:55555"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.True(t, called)
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
|
||||
@@ -160,6 +159,49 @@ func (c *Client) printClients(data map[string]any) {
|
||||
for _, item := range clients {
|
||||
c.printClientRow(item)
|
||||
}
|
||||
|
||||
c.printInboundListeners(clients)
|
||||
}
|
||||
|
||||
func (c *Client) printInboundListeners(clients []any) {
|
||||
type row struct {
|
||||
accountID string
|
||||
tunnelIP string
|
||||
httpsPort int
|
||||
httpPort int
|
||||
}
|
||||
var rows []row
|
||||
for _, item := range clients {
|
||||
client, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inbound, ok := client["inbound_listener"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tunnelIP, _ := inbound["tunnel_ip"].(string)
|
||||
httpsPort, _ := inbound["https_port"].(float64)
|
||||
httpPort, _ := inbound["http_port"].(float64)
|
||||
accountID, _ := client["account_id"].(string)
|
||||
rows = append(rows, row{
|
||||
accountID: accountID,
|
||||
tunnelIP: tunnelIP,
|
||||
httpsPort: int(httpsPort),
|
||||
httpPort: int(httpPort),
|
||||
})
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out)
|
||||
_, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):")
|
||||
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP")
|
||||
_, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78))
|
||||
for _, r := range rows {
|
||||
_, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", r.accountID, r.tunnelIP, r.httpsPort, r.httpPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printClientRow(item any) {
|
||||
@@ -219,7 +261,14 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta
|
||||
}
|
||||
|
||||
func (c *Client) printClientStatus(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"])
|
||||
if inbound, ok := data["inbound_listener"].(map[string]any); ok {
|
||||
tunnelIP, _ := inbound["tunnel_ip"].(string)
|
||||
httpsPort, _ := inbound["https_port"].(float64)
|
||||
httpPort, _ := inbound["http_port"].(float64)
|
||||
_, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort))
|
||||
}
|
||||
_, _ = fmt.Fprintln(c.out)
|
||||
if status, ok := data["status"].(string); ok {
|
||||
_, _ = fmt.Fprint(c.out, status)
|
||||
}
|
||||
|
||||
@@ -61,6 +61,23 @@ type clientProvider interface {
|
||||
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
|
||||
}
|
||||
|
||||
// InboundListenerInfo describes a per-account inbound listener as
|
||||
// surfaced through the debug HTTP handler. Mirrors the proto sub-message
|
||||
// emitted with SendStatusUpdate so dashboards and CLI tooling see the
|
||||
// same shape.
|
||||
type InboundListenerInfo struct {
|
||||
TunnelIP string `json:"tunnel_ip"`
|
||||
HTTPSPort uint16 `json:"https_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
}
|
||||
|
||||
// InboundProvider exposes per-account inbound listener state. Optional;
|
||||
// when nil the debug endpoint omits the inbound section entirely so the
|
||||
// existing JSON shape stays additive.
|
||||
type InboundProvider interface {
|
||||
InboundListeners() map[types.AccountID]InboundListenerInfo
|
||||
}
|
||||
|
||||
// healthChecker provides health probe state.
|
||||
type healthChecker interface {
|
||||
ReadinessProbe() bool
|
||||
@@ -80,6 +97,7 @@ type Handler struct {
|
||||
provider clientProvider
|
||||
health healthChecker
|
||||
certStatus certStatus
|
||||
inbound InboundProvider
|
||||
logger *log.Logger
|
||||
startTime time.Time
|
||||
templates *template.Template
|
||||
@@ -108,6 +126,13 @@ func (h *Handler) SetCertStatus(cs certStatus) {
|
||||
h.certStatus = cs
|
||||
}
|
||||
|
||||
// SetInboundProvider wires per-account inbound listener observability.
|
||||
// Pass nil (or skip the call) to keep the inbound section out of debug
|
||||
// responses on proxies that don't run --private-inbound.
|
||||
func (h *Handler) SetInboundProvider(p InboundProvider) {
|
||||
h.inbound = p
|
||||
}
|
||||
|
||||
func (h *Handler) loadTemplates() error {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
@@ -323,23 +348,35 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
if wantJSON {
|
||||
var inboundAll map[types.AccountID]InboundListenerInfo
|
||||
if h.inbound != nil {
|
||||
inboundAll = h.inbound.InboundListeners()
|
||||
}
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
row := map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"service_count": info.ServiceCount,
|
||||
"service_keys": info.ServiceKeys,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
if inb, ok := inboundAll[id]; ok {
|
||||
row["inbound_listener"] = inb
|
||||
}
|
||||
clientsJSON = append(clientsJSON, row)
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
})
|
||||
}
|
||||
if len(inboundAll) > 0 {
|
||||
resp["inbound_listener_count"] = len(inboundAll)
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -421,10 +458,14 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
})
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
resp := map[string]interface{}{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
}
|
||||
if info, ok := h.inboundInfoFor(accountID); ok {
|
||||
resp["inbound_listener"] = info
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -437,6 +478,18 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
// inboundInfoFor returns the inbound listener info for an account, or
|
||||
// ok=false when no inbound provider is wired or the account has no live
|
||||
// listener.
|
||||
func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) {
|
||||
if h.inbound == nil {
|
||||
return InboundListenerInfo{}, false
|
||||
}
|
||||
all := h.inbound.InboundListeners()
|
||||
info, ok := all[accountID]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
|
||||
@@ -52,8 +52,15 @@ type CapturedData struct {
|
||||
origin ResponseOrigin
|
||||
clientIP netip.Addr
|
||||
userID string
|
||||
authMethod string
|
||||
metadata map[string]string
|
||||
userEmail string
|
||||
userGroups []string
|
||||
// userGroupNames pairs positionally with userGroups; populated from
|
||||
// the JWT's group_names claim or from ValidateSession/Tunnel
|
||||
// responses. Slice may be shorter than userGroups for tokens minted
|
||||
// before names were resolvable.
|
||||
userGroupNames []string
|
||||
authMethod string
|
||||
metadata map[string]string
|
||||
}
|
||||
|
||||
// NewCapturedData creates a CapturedData with the given request ID.
|
||||
@@ -138,6 +145,81 @@ func (c *CapturedData) GetUserID() string {
|
||||
return c.userID
|
||||
}
|
||||
|
||||
// SetUserEmail records the authenticated user's email address. Used by
|
||||
// policy-aware middlewares to stamp identity onto upstream requests
|
||||
// (e.g. x-litellm-end-user-id) without a management round-trip.
|
||||
func (c *CapturedData) SetUserEmail(email string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.userEmail = email
|
||||
}
|
||||
|
||||
// GetUserEmail returns the authenticated user's email address. Returns
|
||||
// the empty string when the auth path didn't carry an email (e.g.
|
||||
// non-OIDC schemes or legacy JWTs minted before the email claim).
|
||||
func (c *CapturedData) GetUserEmail() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.userEmail
|
||||
}
|
||||
|
||||
// SetUserGroups records the authenticated user's group memberships so
|
||||
// downstream policy-aware middlewares can authorise the request without
|
||||
// an additional management round-trip. The auth middleware populates this
|
||||
// from ValidateSessionResponse / ValidateTunnelPeerResponse and from the
|
||||
// session JWT's groups claim on cookie-bearing requests.
|
||||
func (c *CapturedData) SetUserGroups(groups []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(groups) == 0 {
|
||||
c.userGroups = nil
|
||||
return
|
||||
}
|
||||
c.userGroups = append(c.userGroups[:0], groups...)
|
||||
}
|
||||
|
||||
// GetUserGroups returns a copy of the authenticated user's group
|
||||
// memberships.
|
||||
func (c *CapturedData) GetUserGroups() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if len(c.userGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(c.userGroups))
|
||||
copy(out, c.userGroups)
|
||||
return out
|
||||
}
|
||||
|
||||
// SetUserGroupNames records the human-readable display names for the
|
||||
// user's groups, ordered identically to UserGroups (positional
|
||||
// pairing). Stamped onto upstream requests as X-NetBird-Groups so
|
||||
// downstream services can read names rather than opaque ids.
|
||||
func (c *CapturedData) SetUserGroupNames(names []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if len(names) == 0 {
|
||||
c.userGroupNames = nil
|
||||
return
|
||||
}
|
||||
c.userGroupNames = append(c.userGroupNames[:0], names...)
|
||||
}
|
||||
|
||||
// GetUserGroupNames returns a copy of the authenticated user's group
|
||||
// display names. Position i pairs with UserGroups[i]. May be shorter
|
||||
// than UserGroups for tokens minted before names were resolvable; the
|
||||
// consumer should fall back to ids for missing positions.
|
||||
func (c *CapturedData) GetUserGroupNames() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if len(c.userGroupNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(c.userGroupNames))
|
||||
copy(out, c.userGroupNames)
|
||||
return out
|
||||
}
|
||||
|
||||
// SetAuthMethod sets the authentication method used.
|
||||
func (c *CapturedData) SetAuthMethod(method string) {
|
||||
c.mu.Lock()
|
||||
|
||||
@@ -86,6 +86,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if pt.RequestTimeout > 0 {
|
||||
ctx = types.WithDialTimeout(ctx, pt.RequestTimeout)
|
||||
}
|
||||
if pt.DirectUpstream {
|
||||
ctx = roundtrip.WithDirectUpstream(ctx)
|
||||
}
|
||||
|
||||
rewriteMatchedPath := result.matchedPath
|
||||
if pt.PathRewrite == PathRewritePreserve {
|
||||
@@ -142,6 +145,8 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost
|
||||
r.Out.Header.Set(k, v)
|
||||
}
|
||||
|
||||
stampNetBirdIdentity(r)
|
||||
|
||||
clientIP := extractHostIP(r.In.RemoteAddr)
|
||||
|
||||
if isTrustedAddr(clientIP, p.trustedProxies) {
|
||||
@@ -426,3 +431,70 @@ func opErrorContains(err error, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
// headerNetBirdUser carries the authenticated user's display identity
|
||||
// (email when the peer is attached to a user, else peer name) onto
|
||||
// upstream requests. Stripped from inbound requests before stamping
|
||||
// so a client can't spoof identity by setting the header themselves.
|
||||
headerNetBirdUser = "X-NetBird-User"
|
||||
// headerNetBirdGroups carries the user's group display names as a
|
||||
// comma-separated list. Falls back to group IDs at positions where a
|
||||
// name wasn't available at session-mint time. Labels containing a
|
||||
// comma or any non-printable byte are dropped at stamp time so the
|
||||
// list is unambiguously splittable by consumers.
|
||||
headerNetBirdGroups = "X-NetBird-Groups"
|
||||
)
|
||||
|
||||
// isHeaderValueSafe reports whether v is a valid RFC 7230 field-value:
|
||||
// VCHAR (0x21-0x7E), SP (0x20), or HTAB (0x09). Empty values are
|
||||
// rejected; the caller decides whether to omit the header entirely.
|
||||
func isHeaderValueSafe(v string) bool {
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
if c == '\t' || (c >= 0x20 && c <= 0x7E) {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// stampNetBirdIdentity injects authenticated identity onto outbound
|
||||
// requests as X-NetBird-User and X-NetBird-Groups. Always strips any
|
||||
// client-sent values first (anti-spoof). Skips when the request didn't
|
||||
// carry CapturedData (early-path errors, internal endpoints).
|
||||
func stampNetBirdIdentity(r *httputil.ProxyRequest) {
|
||||
r.Out.Header.Del(headerNetBirdUser)
|
||||
r.Out.Header.Del(headerNetBirdGroups)
|
||||
|
||||
cd := CapturedDataFromContext(r.In.Context())
|
||||
if cd == nil {
|
||||
return
|
||||
}
|
||||
if email := cd.GetUserEmail(); isHeaderValueSafe(email) {
|
||||
r.Out.Header.Set(headerNetBirdUser, email)
|
||||
}
|
||||
groupIDs := cd.GetUserGroups()
|
||||
if len(groupIDs) == 0 {
|
||||
return
|
||||
}
|
||||
groupNames := cd.GetUserGroupNames()
|
||||
labels := make([]string, 0, len(groupIDs))
|
||||
for i, id := range groupIDs {
|
||||
label := id
|
||||
if i < len(groupNames) && groupNames[i] != "" {
|
||||
label = groupNames[i]
|
||||
}
|
||||
if !isHeaderValueSafe(label) || strings.ContainsRune(label, ',') {
|
||||
continue
|
||||
}
|
||||
labels = append(labels, label)
|
||||
}
|
||||
if len(labels) > 0 {
|
||||
r.Out.Header.Set(headerNetBirdGroups, strings.Join(labels, ","))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1067,3 +1067,245 @@ func TestClassifyProxyError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"client-supplied X-NetBird-User must be stripped when no captured identity is present")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"client-supplied X-NetBird-Groups must be stripped when no captured identity is present")
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-eng", "grp-ops"})
|
||||
cd.SetUserGroupNames([]string{"engineering", "operations"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "alice@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
|
||||
"captured email must overwrite any spoofed value")
|
||||
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"group display names must be CSV-joined in positional order")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty covers the
|
||||
// tunnel-peer-without-user case (machine agents, unattached proxy peers).
|
||||
// The proxy must still stamp the peer's groups so downstream services can
|
||||
// authorise, but X-NetBird-User stays unset — only its inbound stripping
|
||||
// must happen.
|
||||
func TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserGroups([]string{"grp-machines"})
|
||||
cd.SetUserGroupNames([]string{"machines"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"X-NetBird-User must remain unset when CapturedData carries no email")
|
||||
assert.Equal(t, "machines", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"groups must still be stamped for peers without a user identity")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty covers the symmetric
|
||||
// case: identity-resolved user without resolved group memberships.
|
||||
func TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("carol@netbird.io")
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "carol@netbird.io", pr.Out.Header.Get(headerNetBirdUser),
|
||||
"email must be stamped even when no groups are captured")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"X-NetBird-Groups must remain unset when CapturedData carries no groups")
|
||||
}
|
||||
|
||||
func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("bob@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
|
||||
// "grp-b" gets an explicit empty-string display name (not just a
|
||||
// shorter slice). Both gap shapes must fall back to the id.
|
||||
cd.SetUserGroupNames([]string{"alpha", "", ""})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "alpha,grp-b,grp-c", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"empty-string and out-of-range name slots must both fall back to the group id")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_DropsLabelsWithComma covers the
|
||||
// comma-separator constraint: a group display name that itself contains
|
||||
// a comma is dropped from the header (rather than corrupting the list),
|
||||
// and the remaining labels are stamped.
|
||||
func TestStampNetBirdIdentity_DropsLabelsWithComma(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"})
|
||||
cd.SetUserGroupNames([]string{"engineering", "EU, EMEA", "operations"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"group label with embedded comma must be dropped, remaining labels stamped")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_RejectsControlCharsInEmail covers the
|
||||
// header-injection defence: an email value containing CR/LF/control
|
||||
// chars is omitted entirely (not partially stamped) so the upstream
|
||||
// request stays well-formed and no header injection is possible.
|
||||
func TestStampNetBirdIdentity_RejectsControlCharsInEmail(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io\r\nX-Admin: yes")
|
||||
cd.SetUserGroups([]string{"grp-a"})
|
||||
cd.SetUserGroupNames([]string{"engineering"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"email with CR/LF must be dropped, not partially stamped")
|
||||
assert.Equal(t, "engineering", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"groups remain stampable even when email is invalid")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_RejectsControlCharsInGroup covers the
|
||||
// per-label defence: a group name with a control char is silently
|
||||
// dropped, the rest are stamped.
|
||||
func TestStampNetBirdIdentity_RejectsControlCharsInGroup(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-a", "grp-b"})
|
||||
cd.SetUserGroupNames([]string{"engineering\r\nsneaky", "operations"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "operations", pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"group label with control char must be dropped, valid ones kept")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid covers the
|
||||
// edge case where every group label is rejected: the header must not be
|
||||
// set at all (rather than set to an empty string).
|
||||
func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
cd.SetUserEmail("alice@netbird.io")
|
||||
cd.SetUserGroups([]string{"grp-a", "grp-b"})
|
||||
cd.SetUserGroupNames([]string{"with,comma", "with\nbreak"})
|
||||
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
_, present := pr.Out.Header[http.CanonicalHeaderKey(headerNetBirdGroups)]
|
||||
assert.False(t, present,
|
||||
"X-NetBird-Groups must not be set when every group label is rejected")
|
||||
}
|
||||
|
||||
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
|
||||
// that carry CapturedData with no identity fields populated (e.g. the
|
||||
// auth middleware ran but the request didn't authenticate). Both
|
||||
// headers must be cleared and neither stamped.
|
||||
func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io")
|
||||
pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin")
|
||||
pr.Out.Header = pr.In.Header.Clone()
|
||||
|
||||
cd := NewCapturedData("req-1")
|
||||
pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd))
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser),
|
||||
"X-NetBird-User must be stripped when CapturedData has no email")
|
||||
assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups),
|
||||
"X-NetBird-Groups must be stripped when CapturedData has no groups")
|
||||
}
|
||||
|
||||
@@ -28,6 +28,10 @@ type PathTarget struct {
|
||||
RequestTimeout time.Duration
|
||||
PathRewrite PathRewriteMode
|
||||
CustomHeaders map[string]string
|
||||
// DirectUpstream selects the stdlib HTTP transport (host network stack)
|
||||
// over the embedded NetBird WireGuard client when forwarding requests
|
||||
// to this target. Default false → embedded client (existing behaviour).
|
||||
DirectUpstream bool
|
||||
}
|
||||
|
||||
// Mapping describes how a domain is routed by the HTTP reverse proxy.
|
||||
|
||||
@@ -191,6 +191,18 @@ func (f *Filter) IsObserveOnly(v Verdict) bool {
|
||||
return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve
|
||||
}
|
||||
|
||||
// CheckCIDR runs only the CIDR allow/block evaluation. Use this when
|
||||
// country and CrowdSec checks don't apply — e.g. requests arriving
|
||||
// from the WireGuard overlay, whose source addresses live in the
|
||||
// CGNAT range and have no meaningful geolocation or IP-reputation
|
||||
// data.
|
||||
func (f *Filter) CheckCIDR(addr netip.Addr) Verdict {
|
||||
if f == nil {
|
||||
return Allow
|
||||
}
|
||||
return f.checkCIDR(addr.Unmap())
|
||||
}
|
||||
|
||||
// Check evaluates whether addr is permitted. CIDR rules are evaluated
|
||||
// first because they are O(n) prefix comparisons. Country rules run
|
||||
// only when CIDR checks pass and require a geo lookup. CrowdSec checks
|
||||
|
||||
@@ -514,6 +514,34 @@ func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) {
|
||||
assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil))
|
||||
}
|
||||
|
||||
func TestFilter_CheckCIDR_AllowsWithoutCountryOrCrowdSec(t *testing.T) {
|
||||
cs := &mockCrowdSec{ready: true, decisions: map[string]*CrowdSecDecision{
|
||||
"100.64.5.6": {Type: DecisionBan},
|
||||
}}
|
||||
f := ParseFilter(FilterConfig{
|
||||
AllowedCIDRs: []string{"100.64.0.0/10"},
|
||||
AllowedCountries: []string{"US"},
|
||||
CrowdSec: cs,
|
||||
CrowdSecMode: CrowdSecEnforce,
|
||||
})
|
||||
|
||||
// CheckCIDR skips country + CrowdSec evaluation: an address inside
|
||||
// the allowed CIDR passes even when it would be denied by CrowdSec
|
||||
// or by the country allowlist (CGNAT addresses have no geo data).
|
||||
assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")),
|
||||
"CheckCIDR must not run CrowdSec lookups on overlay traffic")
|
||||
|
||||
// CIDR denials still fire.
|
||||
assert.Equal(t, DenyCIDR, f.CheckCIDR(netip.MustParseAddr("198.51.100.1")),
|
||||
"CheckCIDR must still reject addresses outside the allow list")
|
||||
}
|
||||
|
||||
func TestFilter_CheckCIDR_NilFilter(t *testing.T) {
|
||||
var f *Filter
|
||||
assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")),
|
||||
"CheckCIDR on a nil filter must allow")
|
||||
}
|
||||
|
||||
func TestFilter_HasRestrictions_CrowdSec(t *testing.T) {
|
||||
cs := &mockCrowdSec{ready: true}
|
||||
f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce})
|
||||
|
||||
112
proxy/internal/roundtrip/multi.go
Normal file
112
proxy/internal/roundtrip/multi.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// MultiTransport dispatches each request to either the embedded NetBird
|
||||
// http.RoundTripper or a stdlib http.Transport based on a per-request
|
||||
// context flag set by the reverse-proxy rewrite step. When the flag is
|
||||
// absent (the default for every existing target), requests follow the
|
||||
// embedded NetBird path — current behaviour, preserved.
|
||||
//
|
||||
// The stdlib branch is used when a target was configured with
|
||||
// direct_upstream=true. It dials via the host's network stack, which is
|
||||
// what private (`netbird proxy`) deployments and centralised proxies
|
||||
// fronting host-reachable upstreams (public APIs, LAN services,
|
||||
// localhost sidecars) want.
|
||||
//
|
||||
// An embedded roundtripper is required. To run direct-only (no WG
|
||||
// branch at all), construct the MultiTransport via NewDirectOnly.
|
||||
type MultiTransport struct {
|
||||
embedded http.RoundTripper
|
||||
direct *http.Transport
|
||||
insecure *http.Transport
|
||||
}
|
||||
|
||||
// errNoEmbeddedTransport is returned when a request reaches the
|
||||
// embedded branch on a MultiTransport that wasn't given one. Surfaces
|
||||
// the misconfiguration to the caller instead of silently routing to
|
||||
// the direct branch, which would bypass the WG tunnel.
|
||||
var errNoEmbeddedTransport = errors.New("multitransport: embedded roundtripper not configured")
|
||||
|
||||
// NewMultiTransport wires both branches. embedded is the existing NetBird
|
||||
// roundtripper and must not be nil — pass to NewDirectOnly for a
|
||||
// MultiTransport that only ever uses the direct branch. The direct
|
||||
// branches honour the same NB_PROXY_* tuning env vars as the embedded
|
||||
// transport (see loadTransportConfig) plus a dial-timeout wrapper that
|
||||
// respects types.WithDialTimeout.
|
||||
func NewMultiTransport(embedded http.RoundTripper, logger *log.Logger) *MultiTransport {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
cfg := loadTransportConfig(logger)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
direct := &http.Transport{
|
||||
DialContext: dialWithTimeout(dialer.DialContext),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: cfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: cfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: cfg.maxConnsPerHost,
|
||||
IdleConnTimeout: cfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: cfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: cfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: cfg.responseHeaderTimeout,
|
||||
WriteBufferSize: cfg.writeBufferSize,
|
||||
ReadBufferSize: cfg.readBufferSize,
|
||||
DisableCompression: cfg.disableCompression,
|
||||
}
|
||||
insecure := direct.Clone()
|
||||
insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in
|
||||
|
||||
return &MultiTransport{
|
||||
embedded: embedded,
|
||||
direct: direct,
|
||||
insecure: insecure,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDirectOnly returns a MultiTransport with no embedded branch.
|
||||
// Every request goes through the direct branch regardless of the
|
||||
// per-request flag, so the embedded path can never be reached
|
||||
// silently — wiring code that needs WG must use NewMultiTransport.
|
||||
func NewDirectOnly(logger *log.Logger) *MultiTransport {
|
||||
return NewMultiTransport(noEmbeddedRoundTripper{}, logger)
|
||||
}
|
||||
|
||||
// noEmbeddedRoundTripper is the sentinel embedded transport for
|
||||
// direct-only MultiTransports. RoundTrip is never called in practice
|
||||
// because the direct branch matches every request, but if anything
|
||||
// ever did reach this path it would fail loudly instead of falling
|
||||
// back to direct.
|
||||
type noEmbeddedRoundTripper struct{}
|
||||
|
||||
func (noEmbeddedRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, errNoEmbeddedTransport
|
||||
}
|
||||
|
||||
// RoundTrip dispatches by reading the direct-upstream flag from the request
|
||||
// context. When set, the request is forwarded via the stdlib transport,
|
||||
// honouring the existing per-request skip-TLS-verify flag. Otherwise it
|
||||
// goes through the embedded NetBird roundtripper.
|
||||
func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if DirectUpstreamFromContext(req.Context()) {
|
||||
if skipTLSVerifyFromContext(req.Context()) {
|
||||
return m.insecure.RoundTrip(req)
|
||||
}
|
||||
return m.direct.RoundTrip(req)
|
||||
}
|
||||
if m.embedded == nil {
|
||||
return nil, errNoEmbeddedTransport
|
||||
}
|
||||
return m.embedded.RoundTrip(req)
|
||||
}
|
||||
134
proxy/internal/roundtrip/multi_test.go
Normal file
134
proxy/internal/roundtrip/multi_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubRoundTripper records whether RoundTrip was called and returns a
|
||||
// canned response so tests can assert the dispatch decision without
|
||||
// running a real network.
|
||||
type stubRoundTripper struct {
|
||||
called bool
|
||||
body string
|
||||
}
|
||||
|
||||
func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
s.called = true
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(s.body)),
|
||||
Header: http.Header{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestMultiTransport_DispatchesByContextFlag(t *testing.T) {
|
||||
embedded := &stubRoundTripper{body: "embedded"}
|
||||
mt := NewMultiTransport(embedded, nil)
|
||||
|
||||
t.Run("default routes to embedded", func(t *testing.T) {
|
||||
embedded.called = false
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "embedded path must not error on stubbed transport")
|
||||
require.NotNil(t, resp)
|
||||
_ = resp.Body.Close()
|
||||
assert.True(t, embedded.called, "request without WithDirectUpstream must hit the embedded transport")
|
||||
})
|
||||
|
||||
t.Run("WithDirectUpstream skips embedded", func(t *testing.T) {
|
||||
embedded.called = false
|
||||
// Hit a server we control to verify the stdlib transport is used.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "direct path must dial via stdlib transport")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "direct", string(body), "stdlib transport must reach the test server")
|
||||
assert.False(t, embedded.called, "WithDirectUpstream must bypass the embedded transport")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultiTransport_AppliesEnvOverridesToDirect verifies that the
|
||||
// NB_PROXY_* env vars consumed by loadTransportConfig flow into the
|
||||
// direct branches (previously they only applied to the embedded
|
||||
// roundtripper, so direct-upstream traffic ignored operator tuning).
|
||||
func TestMultiTransport_AppliesEnvOverridesToDirect(t *testing.T) {
|
||||
t.Setenv(EnvMaxIdleConns, "42")
|
||||
t.Setenv(EnvIdleConnTimeout, "11s")
|
||||
t.Setenv(EnvTLSHandshakeTimeout, "7s")
|
||||
|
||||
mt := NewMultiTransport(&stubRoundTripper{body: "embedded"}, nil)
|
||||
|
||||
assert.Equal(t, 42, mt.direct.MaxIdleConns,
|
||||
"NB_PROXY_MAX_IDLE_CONNS must propagate to the direct transport")
|
||||
assert.Equal(t, 11*time.Second, mt.direct.IdleConnTimeout,
|
||||
"NB_PROXY_IDLE_CONN_TIMEOUT must propagate to the direct transport")
|
||||
assert.Equal(t, 7*time.Second, mt.direct.TLSHandshakeTimeout,
|
||||
"NB_PROXY_TLS_HANDSHAKE_TIMEOUT must propagate to the direct transport")
|
||||
assert.Equal(t, 42, mt.insecure.MaxIdleConns,
|
||||
"env tuning must also apply to the insecure-skip-verify direct transport")
|
||||
}
|
||||
|
||||
// TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested guards
|
||||
// against the previous silent fallback: a MultiTransport constructed
|
||||
// without an embedded transport must reject requests that don't
|
||||
// explicitly opt into the direct branch, rather than routing them
|
||||
// over the host stack and bypassing WireGuard.
|
||||
func TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested(t *testing.T) {
|
||||
mt := NewMultiTransport(nil, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
require.Error(t, err, "nil embedded must surface as an explicit error, not a silent direct dispatch")
|
||||
assert.Nil(t, resp)
|
||||
assert.ErrorIs(t, err, errNoEmbeddedTransport,
|
||||
"the error must be the sentinel so callers can distinguish misconfiguration from network failures")
|
||||
}
|
||||
|
||||
// TestMultiTransport_DirectOnlyServesDirectBranch verifies NewDirectOnly
|
||||
// constructs a MultiTransport whose direct branch handles requests with
|
||||
// the direct-upstream flag set, and surfaces the explicit sentinel
|
||||
// when the embedded path is reached.
|
||||
func TestMultiTransport_DirectOnlyServesDirectBranch(t *testing.T) {
|
||||
mt := NewDirectOnly(nil)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := mt.RoundTrip(req)
|
||||
require.NoError(t, err, "direct-only must serve requests that opt into the direct branch")
|
||||
_ = resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
wgReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil)
|
||||
resp, err = mt.RoundTrip(wgReq)
|
||||
if resp != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
require.Error(t, err, "direct-only must refuse requests that didn't opt into the direct branch")
|
||||
assert.Nil(t, resp)
|
||||
assert.ErrorIs(t, err, errNoEmbeddedTransport)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -76,11 +77,11 @@ type clientEntry struct {
|
||||
services map[ServiceKey]serviceInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// ready is closed once the client has been fully initialized.
|
||||
// Callers that find a pending entry wait on this channel before
|
||||
// accessing the client. A nil initErr means success.
|
||||
ready chan struct{}
|
||||
initErr error
|
||||
// inbound is opaque per-account state owned by the NetBird parent's
|
||||
// ReadyHandler. The roundtrip package never inspects this value; it
|
||||
// only stores it so RemovePeer / StopAll can hand it back to the
|
||||
// matching StopHandler. Nil when no inbound integration is active.
|
||||
inbound any
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
@@ -88,6 +89,19 @@ type clientEntry struct {
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
// IdentityForIP resolves a tunnel IP to the peer identity locally known by
|
||||
// this account's embedded client. Returns (pubKey, fqdn) on success.
|
||||
// ok=false means the IP is not in the account's roster — callers can use
|
||||
// that as a fast deny without round-tripping management. The returned
|
||||
// strings carry only what the embedded peerstore exposes; user identity
|
||||
// (UserID / Email / Groups) still flows through ValidateTunnelPeer.
|
||||
func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if e == nil || e.client == nil || !ip.IsValid() {
|
||||
return "", "", false
|
||||
}
|
||||
return e.client.IdentityForIP(ip)
|
||||
}
|
||||
|
||||
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
||||
// It returns a release function that must always be called, and true on success.
|
||||
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
||||
@@ -117,6 +131,12 @@ type ClientConfig struct {
|
||||
MgmtAddr string
|
||||
WGPort uint16
|
||||
PreSharedKey string
|
||||
// BlockInbound mirrors embed.Options.BlockInbound. Set to true on the
|
||||
// standalone proxy where the embedded client never accepts inbound;
|
||||
// set to false on the private/embedded proxy so the engine creates
|
||||
// the ACL manager and applies management's per-policy firewall rules
|
||||
// (which is what gates per-account inbound listeners on the netstack).
|
||||
BlockInbound bool
|
||||
}
|
||||
|
||||
type statusNotifier interface {
|
||||
@@ -142,6 +162,14 @@ type NetBird struct {
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
// Ready. The opaque return value is stored on clientEntry and handed
|
||||
// back to stopHandler when the entry is torn down. Nil disables the
|
||||
// hook entirely (default for the standalone proxy).
|
||||
readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any
|
||||
// stopHandler runs when an account's last service is removed (or the
|
||||
// transport is shutting down). Receives whatever readyHandler returned.
|
||||
stopHandler func(accountID types.AccountID, state any)
|
||||
|
||||
// OnAddPeer, when set, is called after AddPeer completes for a new account
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
@@ -167,9 +195,6 @@ type skipTLSVerifyContextKey struct{}
|
||||
// AddPeer registers a service for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple services can share the same client.
|
||||
//
|
||||
// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux
|
||||
// so that concurrent AddPeer calls for different accounts execute in parallel.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
@@ -177,23 +202,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
ready := entry.ready
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// If the entry is still being initialized by another goroutine, wait.
|
||||
if ready != nil {
|
||||
select {
|
||||
case <-ready:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if entry.initErr != nil {
|
||||
return fmt.Errorf("peer initialization failed: %w", entry.initErr)
|
||||
}
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -210,43 +222,19 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert a placeholder so other goroutines calling AddPeer for the same
|
||||
// account will wait on the ready channel instead of starting a second
|
||||
// client creation.
|
||||
entry = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: si},
|
||||
ready: make(chan struct{}),
|
||||
}
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
createStart := time.Now()
|
||||
created, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
entry, err := n.createClientEntry(ctx, accountID, key, authToken, si)
|
||||
if n.OnAddPeer != nil {
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
entry.initErr = err
|
||||
close(entry.ready)
|
||||
|
||||
n.clientsMux.Lock()
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Transfer any services that were registered by concurrent AddPeer calls
|
||||
// while we were creating the client.
|
||||
n.clientsMux.Lock()
|
||||
for k, v := range entry.services {
|
||||
created.services[k] = v
|
||||
}
|
||||
created.ready = nil
|
||||
n.clients[accountID] = created
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
close(entry.ready)
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
@@ -254,13 +242,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, created.client)
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client.
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -318,9 +306,15 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
BlockInbound: n.clientCfg.BlockInbound,
|
||||
// The embedded proxy peer must never be a stepping stone into
|
||||
// the proxy host's LAN: it only exists to reach NetBird mesh
|
||||
// targets or, when direct_upstream is set, the host network
|
||||
// stack via the MultiTransport's direct branch (which bypasses
|
||||
// the engine routing entirely).
|
||||
BlockLANAccess: true,
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
@@ -385,8 +379,25 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
|
||||
toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
readyHandler := n.readyHandler
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if readyHandler != nil {
|
||||
state := readyHandler(ctx, accountID, client)
|
||||
n.clientsMux.Lock()
|
||||
if e, ok := n.clients[accountID]; ok {
|
||||
e.inbound = state
|
||||
} else if state != nil && n.stopHandler != nil {
|
||||
// Account was removed while readyHandler ran; tear down the
|
||||
// resources it just brought up.
|
||||
stop := n.stopHandler
|
||||
n.clientsMux.Unlock()
|
||||
stop(accountID, state)
|
||||
n.clientsMux.Lock()
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
}
|
||||
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
@@ -432,11 +443,15 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -450,6 +465,9 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
@@ -536,8 +554,12 @@ func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
stopHandler := n.stopHandler
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
if entry.inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
entry.transport.CloseIdleConnections()
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
@@ -590,6 +612,19 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// IdentityForIP resolves a tunnel IP to a peer identity local to the given
|
||||
// account. Delegates to clientEntry.IdentityForIP. Returns ok=false when
|
||||
// the account has no client or the IP is not in its peerstore.
|
||||
func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
n.clientsMux.RLock()
|
||||
entry, exists := n.clients[accountID]
|
||||
n.clientsMux.RUnlock()
|
||||
if !exists {
|
||||
return "", "", false
|
||||
}
|
||||
return entry.IdentityForIP(ip)
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
@@ -645,6 +680,18 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L
|
||||
}
|
||||
}
|
||||
|
||||
// SetClientLifecycle registers callbacks that run when an embedded
|
||||
// client becomes ready and when its entry is torn down. The opaque value
|
||||
// returned by ready is stored on the entry and handed back to stop on
|
||||
// cleanup. Must be called before AddPeer. A nil pair leaves the
|
||||
// outbound-only behaviour intact.
|
||||
func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
n.readyHandler = ready
|
||||
n.stopHandler = stop
|
||||
}
|
||||
|
||||
// dialWithTimeout wraps a DialContext function so that any dial timeout
|
||||
// stored in the context (via types.WithDialTimeout) is applied only to
|
||||
// the connection establishment phase, not the full request lifetime.
|
||||
@@ -687,3 +734,22 @@ func skipTLSVerifyFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// directUpstreamContextKey signals that the request should bypass the embedded
|
||||
// NetBird WireGuard client and dial via the host's network stack instead.
|
||||
// Set by the reverse-proxy rewrite step when the matched target carries
|
||||
// PathTarget.DirectUpstream; consumed by MultiTransport.
|
||||
type directUpstreamContextKey struct{}
|
||||
|
||||
// WithDirectUpstream marks the context so MultiTransport routes the request
|
||||
// through its stdlib transport instead of the embedded NetBird roundtripper.
|
||||
func WithDirectUpstream(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, directUpstreamContextKey{}, true)
|
||||
}
|
||||
|
||||
// DirectUpstreamFromContext reports whether the context has been marked to
|
||||
// bypass the embedded NetBird client.
|
||||
func DirectUpstreamFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package roundtrip
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -305,6 +306,36 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
|
||||
// public lookup short-circuits when no client has been registered for
|
||||
// the queried account. The auth middleware uses ok=false as a fast deny.
|
||||
func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
_, _, ok := nb.IdentityForIP("acct-missing", netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "unknown account must yield ok=false")
|
||||
}
|
||||
|
||||
// TestClientEntry_IdentityForIP_NilClientGuard ensures the receiver
|
||||
// methods stay safe when called on partially-initialized state, which
|
||||
// can happen briefly during AddPeer setup or test fixtures.
|
||||
func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) {
|
||||
var e *clientEntry
|
||||
_, _, ok := e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "nil clientEntry must yield ok=false")
|
||||
|
||||
e = &clientEntry{}
|
||||
_, _, ok = e.IdentityForIP(netip.MustParseAddr("100.64.0.10"))
|
||||
assert.False(t, ok, "clientEntry with nil embed.Client must yield ok=false")
|
||||
}
|
||||
|
||||
// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse covers the input
|
||||
// guard so callers don't have to repeat the check.
|
||||
func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
|
||||
e := &clientEntry{}
|
||||
_, _, ok := e.IdentityForIP(netip.Addr{})
|
||||
assert.False(t, ok, "invalid IP must yield ok=false")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||
|
||||
@@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) {
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(hello)
|
||||
conn := &readerConn{Reader: r}
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
sni, wrapped, _, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(httpReq)
|
||||
conn := &readerConn{Reader: r}
|
||||
_, wrapped, err := PeekClientHello(conn)
|
||||
_, wrapped, _, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -100,28 +100,50 @@ type Router struct {
|
||||
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
|
||||
httpCh chan net.Conn
|
||||
httpListener *chanListener
|
||||
mu sync.RWMutex
|
||||
routes map[SNIHost][]Route
|
||||
fallback *Route
|
||||
draining bool
|
||||
dialResolve DialResolver
|
||||
activeConns sync.WaitGroup
|
||||
activeRelays sync.WaitGroup
|
||||
relaySem chan struct{}
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server.
|
||||
// Set only when NewRouter is called with WithPlainHTTP option (used by
|
||||
// per-account inbound listeners that accept both :80 and :443 traffic).
|
||||
// Nil for the host SNI router and for port routers.
|
||||
httpPlainCh chan net.Conn
|
||||
httpPlainListener *chanListener
|
||||
mu sync.RWMutex
|
||||
routes map[SNIHost][]Route
|
||||
fallback *Route
|
||||
draining bool
|
||||
dialResolve DialResolver
|
||||
activeConns sync.WaitGroup
|
||||
activeRelays sync.WaitGroup
|
||||
relaySem chan struct{}
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
geo restrict.GeoResolver
|
||||
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
||||
// service derive from its context; canceling it kills them immediately.
|
||||
svcCtxs map[types.ServiceID]context.Context
|
||||
svcCancels map[types.ServiceID]context.CancelFunc
|
||||
}
|
||||
|
||||
// RouterOption customises Router construction.
|
||||
type RouterOption func(*Router)
|
||||
|
||||
// WithPlainHTTP enables a parallel plain-HTTP channel on the router. When
|
||||
// set, connections whose first byte is not a TLS handshake are forwarded
|
||||
// to the plain channel returned by HTTPListenerPlain instead of the TLS
|
||||
// channel. Used by per-account inbound listeners that share both :80 and
|
||||
// :443 traffic on the same router.
|
||||
func WithPlainHTTP(addr net.Addr) RouterOption {
|
||||
return func(r *Router) {
|
||||
ch := make(chan net.Conn, httpChannelBuffer)
|
||||
r.httpPlainCh = ch
|
||||
r.httpPlainListener = newChanListener(ch, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// NewRouter creates a new SNI-based connection router.
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router {
|
||||
httpCh := make(chan net.Conn, httpChannelBuffer)
|
||||
return &Router{
|
||||
r := &Router{
|
||||
logger: logger,
|
||||
httpCh: httpCh,
|
||||
httpListener: newChanListener(httpCh, addr),
|
||||
@@ -131,6 +153,10 @@ func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Rou
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// NewPortRouter creates a Router for a dedicated port without an HTTP
|
||||
@@ -153,6 +179,16 @@ func (r *Router) HTTPListener() net.Listener {
|
||||
return r.httpListener
|
||||
}
|
||||
|
||||
// HTTPListenerPlain returns a net.Listener yielding non-TLS connections
|
||||
// for use with a parallel plain http.Server. Returns nil when the router
|
||||
// was not constructed with WithPlainHTTP.
|
||||
func (r *Router) HTTPListenerPlain() net.Listener {
|
||||
if r.httpPlainListener == nil {
|
||||
return nil
|
||||
}
|
||||
return r.httpPlainListener
|
||||
}
|
||||
|
||||
// AddRoute registers an SNI route. Multiple routes for the same host are
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
@@ -254,6 +290,9 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
if r.httpListener != nil {
|
||||
r.httpListener.Close()
|
||||
}
|
||||
if r.httpPlainListener != nil {
|
||||
r.httpPlainListener.Close()
|
||||
}
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
@@ -270,6 +309,7 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
r.logger.Debugf("SNI router accept: %v", err)
|
||||
continue
|
||||
}
|
||||
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
r.activeConns.Add(1)
|
||||
go func() {
|
||||
defer r.activeConns.Done()
|
||||
@@ -278,13 +318,24 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
// HandleConn lets external accept loops feed a connection through the
|
||||
// router's peek-and-dispatch logic. Use this when the same router serves
|
||||
// a secondary listener (for example, a per-account inbound :80 socket
|
||||
// alongside its :443 socket).
|
||||
func (r *Router) HandleConn(ctx context.Context, conn net.Conn) {
|
||||
r.activeConns.Add(1)
|
||||
defer r.activeConns.Done()
|
||||
r.handleConn(ctx, conn)
|
||||
}
|
||||
|
||||
// handleConn peeks at the TLS ClientHello and routes the connection.
|
||||
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
|
||||
// fallback port), skip the TLS peek entirely to avoid read errors on
|
||||
// non-TLS connections and reduce latency.
|
||||
if r.isFallbackOnly() {
|
||||
r.handleUnmatched(ctx, conn)
|
||||
r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr())
|
||||
r.handleUnmatched(ctx, conn, false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -294,11 +345,11 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
r.logger.Debugf("SNI peek: %v", err)
|
||||
r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err)
|
||||
if wrapped != nil {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
r.handleUnmatched(ctx, wrapped, isTLS)
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -313,13 +364,20 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
|
||||
host := SNIHost(strings.ToLower(sni))
|
||||
route, ok := r.lookupRoute(host)
|
||||
r.logger.WithFields(log.Fields{
|
||||
"remote": wrapped.RemoteAddr().String(),
|
||||
"sni": string(host),
|
||||
"match": ok,
|
||||
"tls": isTLS,
|
||||
}).Debug("SNI route lookup")
|
||||
if !ok {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
r.handleUnmatched(ctx, wrapped, isTLS)
|
||||
return
|
||||
}
|
||||
|
||||
if route.Type == RouteHTTP {
|
||||
r.sendToHTTP(wrapped)
|
||||
r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID)
|
||||
r.sendToHTTP(wrapped, isTLS)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -344,15 +402,17 @@ func (r *Router) isFallbackOnly() bool {
|
||||
}
|
||||
|
||||
// handleUnmatched routes a connection that didn't match any SNI route.
|
||||
// This includes ECH/ESNI connections where the cleartext SNI is empty.
|
||||
// This includes ECH/ESNI connections where the cleartext SNI is empty,
|
||||
// and plain (non-TLS) HTTP connections when isTLS is false.
|
||||
// It tries the fallback relay first, then the HTTP channel, and closes
|
||||
// the connection if neither is available.
|
||||
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) {
|
||||
r.mu.RLock()
|
||||
fb := r.fallback
|
||||
r.mu.RUnlock()
|
||||
|
||||
if fb != nil {
|
||||
r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target)
|
||||
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
||||
if !errors.Is(err, errAccessRestricted) {
|
||||
r.logger.WithFields(log.Fields{
|
||||
@@ -364,7 +424,8 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
return
|
||||
}
|
||||
r.sendToHTTP(conn)
|
||||
r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr())
|
||||
r.sendToHTTP(conn, isTLS)
|
||||
}
|
||||
|
||||
// lookupRoute returns the highest-priority route for the given SNI host.
|
||||
@@ -386,10 +447,20 @@ func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
|
||||
}
|
||||
|
||||
// sendToHTTP feeds the connection to the HTTP handler via the channel.
|
||||
// If no HTTP channel is configured (port router), the router is
|
||||
// draining, or the channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
if r.httpCh == nil {
|
||||
// When isTLS is false and a plain channel is configured the connection
|
||||
// is forwarded to the plain channel; otherwise it lands on the TLS
|
||||
// channel. If no usable channel exists, the router is draining, or the
|
||||
// channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) {
|
||||
ch := r.httpCh
|
||||
chanName := "HTTP"
|
||||
if !isTLS && r.httpPlainCh != nil {
|
||||
ch = r.httpPlainCh
|
||||
chanName = "HTTP-plain"
|
||||
}
|
||||
|
||||
if ch == nil {
|
||||
r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -399,14 +470,15 @@ func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
r.mu.RUnlock()
|
||||
|
||||
if draining {
|
||||
r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case r.httpCh <- conn:
|
||||
case ch <- conn:
|
||||
default:
|
||||
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
|
||||
r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1739,3 +1739,97 @@ func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) {
|
||||
connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")}
|
||||
assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR")
|
||||
}
|
||||
|
||||
// TestRouter_PlainHTTP_RoutesToPlainChannel verifies that a plain (non-TLS)
|
||||
// connection lands on the plain HTTP channel when the router was built
|
||||
// with WithPlainHTTP, leaving the TLS channel untouched.
|
||||
func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
|
||||
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
|
||||
router.AddRoute("example.com", Route{Type: RouteHTTP})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "test listener bind must succeed")
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_ = router.Serve(ctx, ln)
|
||||
}()
|
||||
|
||||
// Plain HTTP request (no TLS handshake byte).
|
||||
go func() {
|
||||
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||
}()
|
||||
|
||||
plainListener := router.HTTPListenerPlain()
|
||||
require.NotNil(t, plainListener, "plain listener must be exposed when WithPlainHTTP is set")
|
||||
|
||||
acceptDone := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := plainListener.Accept()
|
||||
if err == nil {
|
||||
acceptDone <- conn
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case conn := <-acceptDone:
|
||||
require.NotNil(t, conn)
|
||||
_ = conn.Close()
|
||||
case <-router.HTTPListener().(*chanListener).ch:
|
||||
t.Fatal("plain HTTP request leaked into TLS channel")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("plain HTTP connection never reached plain channel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled verifies that the
|
||||
// presence of a plain channel does not divert TLS traffic — TLS still
|
||||
// goes to the TLS channel as before.
|
||||
func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
|
||||
router := NewRouter(logger, nil, addr, WithPlainHTTP(addr))
|
||||
router.AddRoute("example.com", Route{Type: RouteHTTP})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "test listener bind must succeed")
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() { _ = router.Serve(ctx, ln) }()
|
||||
|
||||
// Send a TLS ClientHello.
|
||||
go func() {
|
||||
conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tlsConn := tls.Client(conn, &tls.Config{
|
||||
ServerName: "example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
_ = tlsConn.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case conn := <-router.httpCh:
|
||||
require.NotNil(t, conn, "TLS conn should land on the TLS channel")
|
||||
_ = conn.Close()
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("TLS conn never reached the TLS channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,26 +30,30 @@ const (
|
||||
// bytes transparently. If the data is not a valid TLS ClientHello or
|
||||
// contains no SNI extension, sni is empty and err is nil.
|
||||
//
|
||||
// isTLS reports whether the first byte indicated a TLS handshake record.
|
||||
// Callers can use this to distinguish plain (non-TLS) traffic from a TLS
|
||||
// stream that simply lacked an SNI extension or used ECH.
|
||||
//
|
||||
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
|
||||
// real server name is encrypted inside the encrypted_client_hello
|
||||
// extension. This parser only reads the cleartext server_name extension
|
||||
// (type 0x0000), so ECH connections return sni="" and are routed through
|
||||
// the fallback path (or HTTP channel), which is the correct behavior
|
||||
// for a transparent proxy that does not terminate TLS.
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, isTLS bool, err error) {
|
||||
// Read the 5-byte TLS record header into a small stack-friendly buffer.
|
||||
var header [tlsRecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return "", nil, fmt.Errorf("read TLS record header: %w", err)
|
||||
return "", nil, false, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if header[0] != contentTypeHandshake {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
return "", newPeekedConn(conn, header[:]), false, nil
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxClientHelloLen {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
return "", newPeekedConn(conn, header[:]), true, nil
|
||||
}
|
||||
|
||||
// Single allocation for header + payload. The peekedConn takes
|
||||
@@ -59,11 +63,11 @@ func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
|
||||
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
|
||||
if err != nil {
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), true, fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
}
|
||||
|
||||
sni = extractSNI(buf[tlsRecordHeaderLen:])
|
||||
return sni, newPeekedConn(conn, buf), nil
|
||||
return sni, newPeekedConn(conn, buf), true, nil
|
||||
}
|
||||
|
||||
// extractSNI parses a TLS handshake payload to find the SNI extension.
|
||||
|
||||
@@ -29,10 +29,11 @@ func TestPeekClientHello_ValidSNI(t *testing.T) {
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
|
||||
assert.NotNil(t, wrapped, "wrapped connection should not be nil")
|
||||
assert.True(t, isTLS, "TLS ClientHello should be flagged as TLS")
|
||||
|
||||
// Verify the wrapped connection replays the peeked bytes.
|
||||
// Read the first 5 bytes (TLS record header) to confirm replay.
|
||||
@@ -83,10 +84,11 @@ func TestPeekClientHello_MultipleSNIs(t *testing.T) {
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedSNI, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.True(t, isTLS, "TLS handshake should be flagged as TLS")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,10 +104,11 @@ func TestPeekClientHello_NonTLSData(t *testing.T) {
|
||||
_, _ = clientConn.Write(httpData)
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni, "should return empty SNI for non-TLS data")
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.False(t, isTLS, "plain HTTP data should not be flagged as TLS")
|
||||
|
||||
// Verify the wrapped connection still provides the original data.
|
||||
buf := make([]byte, len(httpData))
|
||||
@@ -124,7 +127,7 @@ func TestPeekClientHello_TruncatedHeader(t *testing.T) {
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
_, _, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated header")
|
||||
}
|
||||
|
||||
@@ -140,7 +143,7 @@ func TestPeekClientHello_TruncatedPayload(t *testing.T) {
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
_, _, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated payload")
|
||||
}
|
||||
|
||||
@@ -154,10 +157,11 @@ func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
|
||||
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
sni, wrapped, isTLS, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
assert.True(t, isTLS, "zero-length record should still be a TLS handshake byte")
|
||||
}
|
||||
|
||||
func TestExtractSNI_InvalidPayload(t *testing.T) {
|
||||
|
||||
@@ -54,3 +54,23 @@ func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) {
|
||||
d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration)
|
||||
return d, ok && d > 0
|
||||
}
|
||||
|
||||
// overlayOriginKey is the context key set by per-account inbound
|
||||
// listeners to mark a request as originating from the WireGuard
|
||||
// overlay rather than the public-facing host listener.
|
||||
type overlayOriginKey struct{}
|
||||
|
||||
// WithOverlayOrigin marks the context as originating from the
|
||||
// embedded NetBird overlay (tunnel-side inbound listener).
|
||||
func WithOverlayOrigin(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, overlayOriginKey{}, true)
|
||||
}
|
||||
|
||||
// IsOverlayOrigin reports whether the request reached the proxy via
|
||||
// the overlay listener. Middlewares that only make sense for WAN
|
||||
// traffic (geolocation, CrowdSec IP reputation) should short-circuit
|
||||
// when this is true.
|
||||
func IsOverlayOrigin(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(overlayOriginKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user