mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
@@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -16,11 +19,16 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// errValidationUnavailable indicates that session validation failed due to
|
||||
// an infrastructure error (e.g. gRPC unavailable), not an invalid token.
|
||||
var errValidationUnavailable = errors.New("session validation unavailable")
|
||||
|
||||
type authenticator interface {
|
||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
@@ -40,12 +48,14 @@ type Scheme interface {
|
||||
Authenticate(*http.Request) (token string, promptData string, err error)
|
||||
}
|
||||
|
||||
// DomainConfig holds the authentication and restriction settings for a protected domain.
|
||||
type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
IPRestrictions *restrict.Filter
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
@@ -54,17 +64,18 @@ type validationResult struct {
|
||||
DeniedReason string
|
||||
}
|
||||
|
||||
// Middleware applies per-domain authentication and IP restriction checks.
|
||||
type Middleware struct {
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
geo restrict.GeoResolver
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware.
|
||||
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
|
||||
// locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
|
||||
// NewMiddleware creates a new authentication middleware. The sessionValidator is
|
||||
// optional; if nil, OIDC session tokens are validated locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo restrict.GeoResolver) *Middleware {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
@@ -72,18 +83,12 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middl
|
||||
domains: make(map[string]DomainConfig),
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
geo: geo,
|
||||
}
|
||||
}
|
||||
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
// For each incoming request it will be checked against the middleware's
|
||||
// internal list of protected domains.
|
||||
// If the Host domain in the inbound request is not present, then it will
|
||||
// simply be passed through.
|
||||
// However, if the Host domain is present, then the specified authentication
|
||||
// schemes for that domain will be applied to the request.
|
||||
// In the event that no authentication schemes are defined for the domain,
|
||||
// then the request will also be simply passed through.
|
||||
// Protect wraps next with per-domain authentication and IP restriction checks.
|
||||
// Requests whose Host is not registered pass through unchanged.
|
||||
func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
@@ -94,8 +99,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(config.Schemes) == 0 {
|
||||
if !exists {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@@ -103,6 +107,16 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
if !mw.checkIPRestrictions(w, r, config) {
|
||||
return
|
||||
}
|
||||
|
||||
// Domains with no authentication schemes pass through after IP checks.
|
||||
if len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
}
|
||||
@@ -111,6 +125,10 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithHeaderAuth(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
@@ -124,11 +142,65 @@ func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(config.AccountID)
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
cd.SetAccountID(config.AccountID)
|
||||
cd.SetServiceID(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// checkIPRestrictions validates the client IP against the domain's IP restrictions.
|
||||
// Uses the resolved client IP from CapturedData (which accounts for trusted proxies)
|
||||
// rather than r.RemoteAddr directly.
|
||||
func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request, config DomainConfig) bool {
|
||||
if config.IPRestrictions == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := mw.resolveClientIP(r)
|
||||
if !clientIP.IsValid() {
|
||||
mw.logger.Debugf("IP restriction: cannot resolve client address for %q, denying", r.RemoteAddr)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
|
||||
verdict := config.IPRestrictions.Check(clientIP, mw.geo)
|
||||
if verdict == restrict.Allow {
|
||||
return true
|
||||
}
|
||||
|
||||
reason := verdict.String()
|
||||
mw.blockIPRestriction(r, reason)
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveClientIP extracts the real client IP from CapturedData, falling back to r.RemoteAddr.
|
||||
func (mw *Middleware) resolveClientIP(r *http.Request) netip.Addr {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
if ip := cd.GetClientIP(); ip.IsValid() {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
clientIPStr, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if clientIPStr == "" {
|
||||
clientIPStr = r.RemoteAddr
|
||||
}
|
||||
addr, err := netip.ParseAddr(clientIPStr)
|
||||
if err != nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
return addr.Unmap()
|
||||
}
|
||||
|
||||
// blockIPRestriction sets captured data fields for an IP-restriction block event.
|
||||
func (mw *Middleware) blockIPRestriction(r *http.Request, reason string) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(reason)
|
||||
}
|
||||
mw.logger.Debugf("IP restriction: %s for %s", reason, r.RemoteAddr)
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
@@ -146,6 +218,8 @@ func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Re
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
} else {
|
||||
errDesc = html.EscapeString(errDesc)
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
@@ -170,6 +244,85 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithHeaderAuth checks for a Header auth scheme. If the header validates,
|
||||
// the request is forwarded directly (no redirect), which is important for API clients.
|
||||
func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
for _, scheme := range config.Schemes {
|
||||
hdr, ok := scheme.(Header)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
handled := mw.tryHeaderScheme(w, r, host, config, hdr, next)
|
||||
if handled {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, hdr Header, next http.Handler) bool {
|
||||
token, _, err := hdr.Authenticate(r)
|
||||
if err != nil {
|
||||
return mw.handleHeaderAuthError(w, r, err)
|
||||
}
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader)
|
||||
if err != nil {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
status = http.StatusBadGateway
|
||||
msg = "authentication service unavailable"
|
||||
}
|
||||
http.Error(w, msg, status)
|
||||
return true
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
setHeaderCapturedData(r.Context(), result.UserID)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool {
|
||||
if errors.Is(err, ErrHeaderAuthFailed) {
|
||||
setHeaderCapturedData(r.Context(), "")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return true
|
||||
}
|
||||
mw.logger.WithField("scheme", "header").Warnf("header auth infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return true
|
||||
}
|
||||
|
||||
func setHeaderCapturedData(ctx context.Context, userID string) {
|
||||
cd := proxy.CapturedDataFromContext(ctx)
|
||||
if cd == nil {
|
||||
return
|
||||
}
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodHeader.String())
|
||||
cd.SetUserID(userID)
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
@@ -217,7 +370,13 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
status := http.StatusBadRequest
|
||||
msg := "invalid session token"
|
||||
if errors.Is(err, errValidationUnavailable) {
|
||||
status = http.StatusBadGateway
|
||||
msg = "authentication service unavailable"
|
||||
}
|
||||
http.Error(w, msg, status)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -233,7 +392,21 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
setSessionCookie(w, token, config.SessionExpiration)
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// setSessionCookie writes a session cookie with secure defaults.
|
||||
func setSessionCookie(w http.ResponseWriter, token string, expiration time.Duration) {
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
@@ -245,16 +418,6 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
@@ -275,13 +438,14 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID) error {
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -302,30 +466,28 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
SessionExpiration: expiration,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
IPRestrictions: ipRestrictions,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDomain unregisters authentication for the given domain.
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// validateSessionToken validates a session token, optionally checking group access via gRPC.
|
||||
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
|
||||
// For other auth methods (PIN, password), it validates the JWT locally.
|
||||
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
|
||||
// validateSessionToken validates a session token. OIDC tokens with a configured
|
||||
// validator go through gRPC for group access checks; other methods validate locally.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
|
||||
// For OIDC with a session validator, call the gRPC service to check group access
|
||||
if method == auth.MethodOIDC && mw.sessionValidator != nil {
|
||||
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
|
||||
Domain: host,
|
||||
SessionToken: token,
|
||||
})
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
|
||||
return nil, fmt.Errorf("session validation failed")
|
||||
return nil, fmt.Errorf("%w: %w", errValidationUnavailable, err)
|
||||
}
|
||||
if !resp.Valid {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
@@ -342,7 +504,6 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
}
|
||||
|
||||
// For non-OIDC methods or when no validator is configured, validate JWT locally
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user