add stateless proxy sessions

This commit is contained in:
Alisdair MacLeod
2026-02-04 16:52:35 +00:00
parent 476785b122
commit 694ae13418
16 changed files with 744 additions and 774 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -75,6 +76,14 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
reverseProxy.Auth = authConfig
// Generate session JWT signing keys
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
reverseProxy.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"time"
"github.com/netbirdio/netbird/util/crypt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -58,15 +59,10 @@ type BearerAuthConfig struct {
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type LinkAuthConfig struct {
Enabled bool `json:"enabled"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
LinkAuth *LinkAuthConfig `json:"link_auth,omitempty" gorm:"serializer:json"`
}
type OIDCValidationConfig struct {
@@ -83,14 +79,16 @@ type ReverseProxyMeta struct {
}
type ReverseProxy struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Enabled bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
Targets []Target `gorm:"serializer:json"`
Enabled bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
@@ -132,12 +130,6 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
}
}
if r.Auth.LinkAuth != nil {
authConfig.LinkAuth = &api.LinkAuthConfig{
Enabled: r.Auth.LinkAuth.Enabled,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
for _, target := range r.Targets {
@@ -199,7 +191,10 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
})
}
auth := &proto.Authentication{}
auth := &proto.Authentication{
SessionKey: r.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
auth.Password = true
@@ -210,16 +205,7 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
}
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
auth.Oidc = &proto.OIDC{
Issuer: oidcConfig.Issuer,
Audiences: oidcConfig.Audiences,
KeysLocation: oidcConfig.KeysLocation,
MaxTokenAge: oidcConfig.MaxTokenAgeSeconds,
}
}
if r.Auth.LinkAuth != nil && r.Auth.LinkAuth.Enabled {
auth.Link = true
auth.Oidc = true
}
return &proto.ProxyMapping{
@@ -291,13 +277,6 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
}
r.Auth.BearerAuth = bearerAuth
}
if req.Auth.LinkAuth != nil {
r.Auth.LinkAuth = &LinkAuthConfig{
Enabled: req.Auth.LinkAuth.Enabled,
}
}
}
func (r *ReverseProxy) Validate() error {
@@ -322,3 +301,53 @@ func (r *ReverseProxy) Validate() error {
func (r *ReverseProxy) EventMeta() map[string]any {
return map[string]any{"name": r.Name, "domain": r.Domain}
}
func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]Target, len(r.Targets))
copy(targets, r.Targets)
return &ReverseProxy{
ID: r.ID,
AccountID: r.AccountID,
Name: r.Name,
Domain: r.Domain,
Targets: targets,
Enabled: r.Enabled,
Auth: r.Auth,
Meta: r.Meta,
SessionPrivateKey: r.SessionPrivateKey,
SessionPublicKey: r.SessionPublicKey,
}
}
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if r.SessionPrivateKey != "" {
var err error
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,69 @@
package sessionkey
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/netbird/proxy/auth"
)
type KeyPair struct {
PrivateKey string
PublicKey string
}
type Claims struct {
jwt.RegisteredClaims
Method auth.Method `json:"method"`
}
func GenerateKeyPair() (*KeyPair, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate ed25519 key: %w", err)
}
return &KeyPair{
PrivateKey: base64.StdEncoding.EncodeToString(priv),
PublicKey: base64.StdEncoding.EncodeToString(pub),
}, nil
}
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil {
return "", fmt.Errorf("decode private key: %w", err)
}
if len(privKeyBytes) != ed25519.PrivateKeySize {
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
}
privKey := ed25519.PrivateKey(privKeyBytes)
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: auth.SessionJWTIssuer,
Subject: userID,
Audience: jwt.ClaimStrings{domain},
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
Method: method,
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
signedToken, err := token.SignedString(privKey)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signedToken, nil
}

View File

@@ -23,9 +23,11 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -317,7 +319,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
// TODO: log the error
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
}
var authenticated bool
var userId string
var method proxyauth.Method
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
auth := proxy.Auth.PinAuth
@@ -327,6 +332,8 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
break
}
authenticated = subtle.ConstantTimeCompare([]byte(auth.Pin), []byte(v.Pin.GetPin())) == 1
userId = "pin-user"
method = proxyauth.MethodPIN
case *proto.AuthenticateRequest_Password:
auth := proxy.Auth.PasswordAuth
if auth == nil || !auth.Enabled {
@@ -335,9 +342,28 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
break
}
authenticated = subtle.ConstantTimeCompare([]byte(auth.Password), []byte(v.Password.GetPassword())) == 1
userId = "password-user"
method = proxyauth.MethodPassword
}
var token string
if authenticated && proxy.SessionPrivateKey != "" {
token, err = sessionkey.SignToken(
proxy.SessionPrivateKey,
userId,
proxy.Domain,
method,
proxyauth.DefaultSessionExpiry,
)
if err != nil {
log.WithError(err).Error("Failed to sign session token")
authenticated = false
}
}
return &proto.AuthenticateResponse{
Success: authenticated,
Success: authenticated,
SessionToken: token,
}, nil
}
@@ -516,3 +542,35 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
return verifier, redirectURL, nil
}
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the proxy by domain to get its signing key
proxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
if err != nil {
return "", fmt.Errorf("get reverse proxies: %w", err)
}
var proxy *reverseproxy.ReverseProxy
for _, p := range proxies {
if p.Domain == domain {
proxy = p
break
}
}
if proxy == nil {
return "", fmt.Errorf("reverse proxy not found for domain: %s", domain)
}
if proxy.SessionPrivateKey == "" {
return "", fmt.Errorf("no session key configured for domain: %s", domain)
}
return sessionkey.SignToken(
proxy.SessionPrivateKey,
userID,
domain,
method,
proxyauth.DefaultSessionExpiry,
)
}

View File

@@ -1,6 +1,7 @@
package proxy
import (
"context"
"net/http"
"net/url"
@@ -10,6 +11,7 @@ import (
"golang.org/x/oauth2"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/proxy/auth"
)
type AuthCallbackHandler struct {
@@ -65,16 +67,77 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
return
}
redirectQuery := redirectURL.Query()
redirectQuery.Set("access_token", token.AccessToken)
if token.RefreshToken != "" {
redirectQuery.Set("refresh_token", token.RefreshToken)
// Extract user ID from the OIDC token
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
if userID == "" {
log.Error("Failed to extract user ID from OIDC token")
http.Error(w, "Failed to validate token", http.StatusUnauthorized)
return
}
// Generate session JWT instead of passing OIDC access_token
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
if err != nil {
log.WithError(err).Error("Failed to create session token")
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
redirectURL.RawQuery = redirectQuery.Encode()
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
redirectURL.Scheme = "https"
log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token")
// Pass the session token in the URL query parameter. The proxy middleware will
// extract it, validate it, set its own cookie, and redirect to remove the token from the URL.
// We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the
// management server cannot set cookies for the proxy's domain.
query := redirectURL.Query()
query.Set("session_token", sessionToken)
redirectURL.RawQuery = query.Encode()
log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
}
// extractUserIDFromToken extracts the user ID from an OIDC token.
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
// Try to get ID token from the oauth2 token extras
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
log.Warn("No id_token in OIDC response")
return ""
}
verifier := provider.Verifier(&oidc.Config{
ClientID: config.ClientID,
})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
log.WithError(err).Warn("Failed to verify ID token")
return ""
}
// Extract claims
var claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
UserID string `json:"user_id"`
}
if err := idToken.Claims(&claims); err != nil {
log.WithError(err).Warn("Failed to extract claims from ID token")
return ""
}
// Prefer subject, fall back to user_id or email
if claims.Subject != "" {
return claims.Subject
}
if claims.UserID != "" {
return claims.UserID
}
if claims.Email != "" {
return claims.Email
}
return ""
}

View File

@@ -4609,7 +4609,11 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
result := s.db.Create(proxy)
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
}
result := s.db.Create(proxyCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create reverse proxy to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create reverse proxy to store")
@@ -4619,7 +4623,11 @@ func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.R
}
func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
result := s.db.Select("*").Save(proxy)
proxyCopy := proxy.Copy()
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
}
result := s.db.Select("*").Save(proxyCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
@@ -4659,6 +4667,10 @@ func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength Locking
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
return proxy, nil
}
@@ -4674,6 +4686,10 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
return nil, status.Errorf(status.Internal, "failed to get reverse proxy by domain from store")
}
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
return proxy, nil
}
@@ -4690,6 +4706,12 @@ func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingSt
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
}
return proxyList, nil
}
@@ -4706,6 +4728,12 @@ func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength Lo
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
}
for _, proxy := range proxyList {
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
}
}
return proxyList, nil
}