add stateless proxy sessions

This commit is contained in:
Alisdair MacLeod
2026-02-04 16:52:35 +00:00
parent 476785b122
commit 694ae13418
16 changed files with 744 additions and 774 deletions

View File

@@ -2,73 +2,50 @@ package auth
import (
"context"
"crypto/rand"
"crypto/ed25519"
"encoding/base64"
"net"
"net/http"
"sync"
"time"
"github.com/netbirdio/netbird/proxy/auth"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Method string
var (
MethodPassword Method = "password"
MethodPIN Method = "pin"
MethodOIDC Method = "oidc"
MethodLink Method = "link"
)
func (m Method) String() string {
return string(m)
}
const (
sessionCookieName = "nb_session"
sessionExpiration = 24 * time.Hour
)
type session struct {
UserID string
Method Method
CreatedAt time.Time
}
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
type Scheme interface {
Type() Method
Type() auth.Method
// Authenticate should check the passed request and determine whether
// it represents an authenticated user request. If it does not, then
// an empty string should indicate an unauthenticated request which
// will be rejected; optionally, it can also return any data that should
// be included in a UI template when prompting the user to authenticate.
// If the request is authenticated, then a user id should be returned.
Authenticate(*http.Request) (userid string, promptData string)
// If the request is authenticated, then a session token should be returned.
Authenticate(*http.Request) (token string, promptData string)
}
type DomainConfig struct {
Schemes []Scheme
SessionPublicKey ed25519.PublicKey
SessionExpiration time.Duration
}
type Middleware struct {
domainsMux sync.RWMutex
domains map[string][]Scheme
sessionsMux sync.RWMutex
sessions map[string]*session
domainsMux sync.RWMutex
domains map[string]DomainConfig
}
func NewMiddleware() *Middleware {
mw := &Middleware{
domains: make(map[string][]Scheme),
sessions: make(map[string]*session),
return &Middleware{
domains: make(map[string]DomainConfig),
}
// TODO: goroutine is leaked here.
go mw.cleanupSessions()
return mw
}
// Protect applies authentication middleware to the passed handler.
@@ -87,24 +64,20 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
host = r.Host
}
mw.domainsMux.RLock()
schemes, exists := mw.domains[host]
config, exists := mw.domains[host]
mw.domainsMux.RUnlock()
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
if !exists || len(schemes) == 0 {
if !exists || len(config.Schemes) == 0 {
next.ServeHTTP(w, r)
return
}
// Check for an existing session to avoid users having to authenticate for every request.
// TODO: This does not work if you are load balancing across multiple proxy servers.
if cookie, err := r.Cookie(sessionCookieName); err == nil {
mw.sessionsMux.RLock()
sess, ok := mw.sessions[cookie.Value]
mw.sessionsMux.RUnlock()
if ok {
ctx := withAuthMethod(r.Context(), sess.Method)
ctx = withAuthUser(ctx, sess.UserID)
// Check for an existing session cookie (contains JWT)
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
ctx := withAuthMethod(r.Context(), auth.Method(method))
ctx = withAuthUser(ctx, userID)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
@@ -112,28 +85,59 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
// Try to authenticate with each scheme.
methods := make(map[string]string)
for _, s := range schemes {
userid, promptData := s.Authenticate(r)
if userid != "" {
mw.createSession(w, r, userid, s.Type())
// Clean the path and redirect to the naked URL.
// This is intended to prevent leaking potentially
// sensitive query parameters for authentication
// methods.
http.Redirect(w, r, r.URL.Path, http.StatusFound)
for _, scheme := range config.Schemes {
token, promptData := scheme.Authenticate(r)
if token != "" {
userid, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey)
if err != nil {
// TODO: log, this should never fail.
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
expiration := config.SessionExpiration
if expiration == 0 {
expiration = auth.DefaultSessionExpiry
}
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(expiration.Seconds()),
})
ctx := withAuthMethod(r.Context(), scheme.Type())
ctx = withAuthUser(ctx, userid)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
methods[s.Type().String()] = promptData
methods[scheme.Type().String()] = promptData
}
web.ServeHTTP(w, r, map[string]any{"methods": methods})
})
}
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) {
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
if err != nil {
// TODO: log
return
}
if len(pubKeyBytes) != ed25519.PublicKeySize {
// TODO: log
return
}
mw.domainsMux.Lock()
defer mw.domainsMux.Unlock()
mw.domains[domain] = schemes
mw.domains[domain] = DomainConfig{
Schemes: schemes,
SessionPublicKey: pubKeyBytes,
SessionExpiration: expiration,
}
}
func (mw *Middleware) RemoveDomain(domain string) {
@@ -141,39 +145,3 @@ func (mw *Middleware) RemoveDomain(domain string) {
defer mw.domainsMux.Unlock()
delete(mw.domains, domain)
}
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
// Generate a random sessionID
b := make([]byte, 32)
_, _ = rand.Read(b)
sessionID := base64.URLEncoding.EncodeToString(b)
mw.sessionsMux.Lock()
mw.sessions[sessionID] = &session{
UserID: userID,
Method: method,
CreatedAt: time.Now(),
}
mw.sessionsMux.Unlock()
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: sessionID,
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
})
}
func (mw *Middleware) cleanupSessions() {
for range time.Tick(time.Minute) {
cutoff := time.Now().Add(-sessionExpiration)
mw.sessionsMux.Lock()
for id, sess := range mw.sessions {
if sess.CreatedAt.Before(cutoff) {
delete(mw.sessions, id)
}
}
mw.sessionsMux.Unlock()
}
}