[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -4,9 +4,12 @@ import (
"context"
"crypto/ed25519"
"encoding/base64"
"errors"
"fmt"
"html"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
@@ -16,11 +19,16 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/restrict"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
// errValidationUnavailable indicates that session validation failed due to
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
var errValidationUnavailable = errors.New("session validation unavailable")
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
@@ -40,12 +48,14 @@ type Scheme interface {
Authenticate(*http.Request) (token string, promptData string, err error)
}
// DomainConfig holds the authentication and restriction settings for a protected domain.
type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
AccountID types.AccountID
ServiceID types.ServiceID
IPRestrictions *restrict.Filter
}
type validationResult struct {
@@ -54,17 +64,18 @@ type validationResult struct {
DeniedReason string
}
// Middleware applies per-domain authentication and IP restriction checks.
type Middleware struct {
domainsMux sync.RWMutex
domains map[string]DomainConfig
logger *log.Logger
sessionValidator SessionValidator
geo restrict.GeoResolver
}
// NewMiddleware creates a new authentication middleware.
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
// NewMiddleware creates a new authentication middleware. The sessionValidator is
// optional; if nil, OIDC session tokens are validated locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
if logger == nil {
logger = log.StandardLogger()
}
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
domains: make(map[string]DomainConfig),
logger: logger,
sessionValidator: sessionValidator,
geo: geo,
}
}
// Protect applies authentication middleware to the passed handler.
// For each incoming request it will be checked against the middleware's
// internal list of protected domains.
// If the Host domain in the inbound request is not present, then it will
// simply be passed through.
// However, if the Host domain is present, then the specified authentication
// schemes for that domain will be applied to the request.
// In the event that no authentication schemes are defined for the domain,
// then the request will also be simply passed through.
// Protect wraps next with per-domain authentication and IP restriction checks.
// Requests whose Host is not registered pass through unchanged.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
config, exists := mw.getDomainConfig(host)
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(config.Schemes) == 0 {
if !exists {
next.ServeHTTP(w, r)
return
}
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Set account and service IDs in captured data for access logging.
setCapturedIDs(r, config)
if !mw.checkIPRestrictions(w, r, config) {
return
}
// Domains with no authentication schemes pass through after IP checks.
if len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
if mw.handleOAuthCallbackError(w, r) {
return
}
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
return
}
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
return
}
mw.authenticateWithSchemes(w, r, host, config)
})
}
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
func setCapturedIDs(r *http.Request, config DomainConfig) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetAccountId(config.AccountID)
cd.SetServiceId(config.ServiceID)
cd.SetAccountID(config.AccountID)
cd.SetServiceID(config.ServiceID)
}
}
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
// rather than r.RemoteAddr directly.
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
if config.IPRestrictions == nil {
return true
}
clientIP := mw.resolveClientIP(r)
if !clientIP.IsValid() {
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
if verdict == restrict.Allow {
return true
}
reason := verdict.String()
mw.blockIPRestriction(r, reason)
http.Error(w, "Forbidden", http.StatusForbidden)
return false
}
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
if ip := cd.GetClientIP(); ip.IsValid() {
return ip
}
}
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
if clientIPStr == "" {
clientIPStr = r.RemoteAddr
}
addr, err := netip.ParseAddr(clientIPStr)
if err != nil {
return netip.Addr{}
}
return addr.Unmap()
}
// blockIPRestriction sets captured data fields for an IP-restriction block event.
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(reason)
}
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
}
// handleOAuthCallbackError checks for error query parameters from an OAuth
// callback and renders the access denied page if present.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
errDesc := r.URL.Query().Get("error_description")
if errDesc == "" {
errDesc = "An error occurred during authentication"
} else {
errDesc = html.EscapeString(errDesc)
}
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
return true
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
return true
}
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
// the request is forwarded directly (no redirect), which is important for API clients.
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
for _, scheme := range config.Schemes {
hdr, ok := scheme.(Header)
if !ok {
continue
}
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
if handled {
return true
}
}
return false
}
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
token, _, err := hdr.Authenticate(r)
if err != nil {
return mw.handleHeaderAuthError(w, r, err)
}
if token == "" {
return false
}
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
if err != nil {
setHeaderCapturedData(r.Context(), "")
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return true
}
if !result.Valid {
setHeaderCapturedData(r.Context(), result.UserID)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
setSessionCookie(w, token, config.SessionExpiration)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetUserID(result.UserID)
cd.SetAuthMethod(auth.MethodHeader.String())
}
next.ServeHTTP(w, r)
return true
}
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
if errors.Is(err, ErrHeaderAuthFailed) {
setHeaderCapturedData(r.Context(), "")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return true
}
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
}
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
return true
}
func setHeaderCapturedData(ctx context.Context, userID string) {
cd := proxy.CapturedDataFromContext(ctx)
if cd == nil {
return
}
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(auth.MethodHeader.String())
cd.SetUserID(userID)
}
// authenticateWithSchemes tries each configured auth scheme in order.
// On success it sets a session cookie and redirects; on failure it renders the login page.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
cd.SetOrigin(proxy.OriginAuth)
cd.SetAuthMethod(scheme.Type().String())
}
http.Error(w, err.Error(), http.StatusBadRequest)
status := http.StatusBadRequest
msg := "invalid session token"
if errors.Is(err, errValidationUnavailable) {
status = http.StatusBadGateway
msg = "authentication service unavailable"
}
http.Error(w, msg, status)
return
}
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
return
}
expiration := config.SessionExpiration
setSessionCookie(w, token, config.SessionExpiration)
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// setSessionCookie writes a session cookie with secure defaults.
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
// Redirect instead of forwarding the auth POST to the backend.
// The browser will follow with a GET carrying the new session cookie.
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(proxy.OriginAuth)
cd.SetUserID(result.UserID)
cd.SetAuthMethod(scheme.Type().String())
}
redirectURL := stripSessionTokenParam(r.URL)
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
// session JWTs. Returns an error if the key is missing or invalid.
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
if len(schemes) == 0 {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = DomainConfig{
AccountID: accountID,
ServiceID: serviceID,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
SessionExpiration: expiration,
AccountID: accountID,
ServiceID: serviceID,
IPRestrictions: ipRestrictions,
}
return nil
}
// RemoveDomain unregisters authentication for the given domain.
func (mw *Middleware) RemoveDomain(domain string) {
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
// validateSessionToken validates a session token, optionally checking group access via gRPC.
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
// For other auth methods (PIN, password), it validates the JWT locally.
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
// validateSessionToken validates a session token. OIDC tokens with a configured
// validator go through gRPC for group access checks; other methods validate locally.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
// For OIDC with a session validator, call the gRPC service to check group access
if method == auth.MethodOIDC && mw.sessionValidator != nil {
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
Domain: host,
SessionToken: token,
})
if err != nil {
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
return nil, fmt.Errorf("session validation failed")
return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
}
if !resp.Valid {
mw.logger.WithFields(log.Fields{
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
return &validationResult{UserID: resp.UserId, Valid: true}, nil
}
// For non-OIDC methods or when no validator is configured, validate JWT locally
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
if err != nil {
return nil, err