mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
add stateless proxy sessions
This commit is contained in:
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
@@ -11,13 +13,13 @@ const (
|
||||
authUserKey requestContextKey = "authUser"
|
||||
)
|
||||
|
||||
func withAuthMethod(ctx context.Context, method Method) context.Context {
|
||||
func withAuthMethod(ctx context.Context, method auth.Method) context.Context {
|
||||
return context.WithValue(ctx, authMethodKey, method)
|
||||
}
|
||||
|
||||
func MethodFromContext(ctx context.Context) Method {
|
||||
func MethodFromContext(ctx context.Context) auth.Method {
|
||||
v := ctx.Value(authMethodKey)
|
||||
method, ok := v.(Method)
|
||||
method, ok := v.(auth.Method)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const linkFormId = "email"
|
||||
|
||||
type Link struct {
|
||||
id, accountId string
|
||||
client authenticator
|
||||
}
|
||||
|
||||
func NewLink(client authenticator, id, accountId string) Link {
|
||||
return Link{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (Link) Type() Method {
|
||||
return MethodLink
|
||||
}
|
||||
|
||||
func (l Link) Authenticate(r *http.Request) (string, string) {
|
||||
email := r.FormValue(linkFormId)
|
||||
|
||||
res, err := l.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: l.id,
|
||||
AccountId: l.accountId,
|
||||
Request: &proto.AuthenticateRequest_Link{
|
||||
Link: &proto.LinkRequest{
|
||||
Email: email,
|
||||
Redirect: "", // TODO: calculate this.
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// TODO: log error here
|
||||
return "", linkFormId
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
// Use the email address as the user identifier.
|
||||
return email, ""
|
||||
}
|
||||
|
||||
return "", linkFormId
|
||||
}
|
||||
@@ -2,73 +2,50 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Method string
|
||||
|
||||
var (
|
||||
MethodPassword Method = "password"
|
||||
MethodPIN Method = "pin"
|
||||
MethodOIDC Method = "oidc"
|
||||
MethodLink Method = "link"
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sessionCookieName = "nb_session"
|
||||
sessionExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
type session struct {
|
||||
UserID string
|
||||
Method Method
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type authenticator interface {
|
||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
type Scheme interface {
|
||||
Type() Method
|
||||
Type() auth.Method
|
||||
// Authenticate should check the passed request and determine whether
|
||||
// it represents an authenticated user request. If it does not, then
|
||||
// an empty string should indicate an unauthenticated request which
|
||||
// will be rejected; optionally, it can also return any data that should
|
||||
// be included in a UI template when prompting the user to authenticate.
|
||||
// If the request is authenticated, then a user id should be returned.
|
||||
Authenticate(*http.Request) (userid string, promptData string)
|
||||
// If the request is authenticated, then a session token should be returned.
|
||||
Authenticate(*http.Request) (token string, promptData string)
|
||||
}
|
||||
|
||||
type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string][]Scheme
|
||||
sessionsMux sync.RWMutex
|
||||
sessions map[string]*session
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
}
|
||||
|
||||
func NewMiddleware() *Middleware {
|
||||
mw := &Middleware{
|
||||
domains: make(map[string][]Scheme),
|
||||
sessions: make(map[string]*session),
|
||||
return &Middleware{
|
||||
domains: make(map[string]DomainConfig),
|
||||
}
|
||||
// TODO: goroutine is leaked here.
|
||||
go mw.cleanupSessions()
|
||||
return mw
|
||||
}
|
||||
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
@@ -87,24 +64,20 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
host = r.Host
|
||||
}
|
||||
mw.domainsMux.RLock()
|
||||
schemes, exists := mw.domains[host]
|
||||
config, exists := mw.domains[host]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(schemes) == 0 {
|
||||
if !exists || len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for an existing session to avoid users having to authenticate for every request.
|
||||
// TODO: This does not work if you are load balancing across multiple proxy servers.
|
||||
if cookie, err := r.Cookie(sessionCookieName); err == nil {
|
||||
mw.sessionsMux.RLock()
|
||||
sess, ok := mw.sessions[cookie.Value]
|
||||
mw.sessionsMux.RUnlock()
|
||||
if ok {
|
||||
ctx := withAuthMethod(r.Context(), sess.Method)
|
||||
ctx = withAuthUser(ctx, sess.UserID)
|
||||
// Check for an existing session cookie (contains JWT)
|
||||
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||
ctx := withAuthMethod(r.Context(), auth.Method(method))
|
||||
ctx = withAuthUser(ctx, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
@@ -112,28 +85,59 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
|
||||
// Try to authenticate with each scheme.
|
||||
methods := make(map[string]string)
|
||||
for _, s := range schemes {
|
||||
userid, promptData := s.Authenticate(r)
|
||||
if userid != "" {
|
||||
mw.createSession(w, r, userid, s.Type())
|
||||
// Clean the path and redirect to the naked URL.
|
||||
// This is intended to prevent leaking potentially
|
||||
// sensitive query parameters for authentication
|
||||
// methods.
|
||||
http.Redirect(w, r, r.URL.Path, http.StatusFound)
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData := scheme.Authenticate(r)
|
||||
if token != "" {
|
||||
userid, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
// TODO: log, this should never fail.
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
ctx := withAuthMethod(r.Context(), scheme.Type())
|
||||
ctx = withAuthUser(ctx, userid)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
methods[s.Type().String()] = promptData
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods})
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) {
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
return
|
||||
}
|
||||
if len(pubKeyBytes) != ed25519.PublicKeySize {
|
||||
// TODO: log
|
||||
return
|
||||
}
|
||||
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = schemes
|
||||
mw.domains[domain] = DomainConfig{
|
||||
Schemes: schemes,
|
||||
SessionPublicKey: pubKeyBytes,
|
||||
SessionExpiration: expiration,
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
@@ -141,39 +145,3 @@ func (mw *Middleware) RemoveDomain(domain string) {
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
|
||||
// Generate a random sessionID
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
sessionID := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
mw.sessionsMux.Lock()
|
||||
mw.sessions[sessionID] = &session{
|
||||
UserID: userID,
|
||||
Method: method,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionID,
|
||||
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
|
||||
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
|
||||
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) cleanupSessions() {
|
||||
for range time.Tick(time.Minute) {
|
||||
cutoff := time.Now().Add(-sessionExpiration)
|
||||
mw.sessionsMux.Lock()
|
||||
for id, sess := range mw.sessions {
|
||||
if sess.CreatedAt.Before(cutoff) {
|
||||
delete(mw.sessions, id)
|
||||
}
|
||||
}
|
||||
mw.sessionsMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,66 +4,41 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type urlGenerator interface {
|
||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||
}
|
||||
|
||||
// OIDCConfig holds configuration for OIDC JWT verification
|
||||
type OIDCConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
// oidcState stores CSRF state with expiration
|
||||
type oidcState struct {
|
||||
OriginalURL string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OIDC implements the Scheme interface for JWT/OIDC authentication
|
||||
type OIDC struct {
|
||||
id, accountId string
|
||||
validator *jwt.Validator
|
||||
maxTokenAgeSeconds int64
|
||||
client urlGenerator
|
||||
id, accountId string
|
||||
client urlGenerator
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
return &OIDC{
|
||||
func NewOIDC(client urlGenerator, id, accountId string) OIDC {
|
||||
return OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
validator: jwt.NewValidator(
|
||||
cfg.Issuer,
|
||||
cfg.Audiences,
|
||||
cfg.KeysLocation,
|
||||
true,
|
||||
),
|
||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
||||
client: client,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
return MethodOIDC
|
||||
func (OIDC) Type() auth.Method {
|
||||
return auth.MethodOIDC
|
||||
}
|
||||
|
||||
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
if token := r.URL.Query().Get("access_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, ""
|
||||
}
|
||||
func (o OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
// Check for the session_token query param (from OIDC redirects).
|
||||
// The management server passes the token in the URL because it cannot set
|
||||
// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
|
||||
if token := r.URL.Query().Get("session_token"); token != "" {
|
||||
return token, ""
|
||||
}
|
||||
|
||||
redirectURL := &url.URL{
|
||||
@@ -84,55 +59,3 @@ func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
|
||||
return "", res.GetUrl()
|
||||
}
|
||||
|
||||
// validateToken validates a JWT ID token and returns the user ID (subject)
|
||||
// Returns empty string if token is invalid.
|
||||
func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
if o.validator == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
idToken, err := o.validator.ValidateAndParse(ctx, token)
|
||||
if err != nil {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
iat, err := idToken.Claims.GetIssuedAt()
|
||||
if err != nil {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
// If max token age is 0 skip this check.
|
||||
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
return extractUserID(idToken)
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@@ -3,13 +3,11 @@ package auth
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
passwordUserId = "password-user"
|
||||
passwordFormId = "password"
|
||||
)
|
||||
const passwordFormId = "password"
|
||||
|
||||
type Password struct {
|
||||
id, accountId string
|
||||
@@ -24,8 +22,8 @@ func NewPassword(client authenticator, id, accountId string) Password {
|
||||
}
|
||||
}
|
||||
|
||||
func (Password) Type() Method {
|
||||
return MethodPassword
|
||||
func (Password) Type() auth.Method {
|
||||
return auth.MethodPassword
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
@@ -37,7 +35,7 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
|
||||
password := r.FormValue(passwordFormId)
|
||||
|
||||
if password == "" {
|
||||
// This cannot be authenticated so not worth wasting time sending the request.
|
||||
// This cannot be authenticated, so not worth wasting time sending the request.
|
||||
return "", passwordFormId
|
||||
}
|
||||
|
||||
@@ -56,7 +54,7 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
return passwordUserId, ""
|
||||
return res.GetSessionToken(), ""
|
||||
}
|
||||
|
||||
return "", passwordFormId
|
||||
|
||||
@@ -3,13 +3,11 @@ package auth
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
pinUserId = "pin-user"
|
||||
pinFormId = "pin"
|
||||
)
|
||||
const pinFormId = "pin"
|
||||
|
||||
type Pin struct {
|
||||
id, accountId string
|
||||
@@ -24,8 +22,8 @@ func NewPin(client authenticator, id, accountId string) Pin {
|
||||
}
|
||||
}
|
||||
|
||||
func (Pin) Type() Method {
|
||||
return MethodPIN
|
||||
func (Pin) Type() auth.Method {
|
||||
return auth.MethodPIN
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
@@ -37,7 +35,7 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
|
||||
pin := r.FormValue(pinFormId)
|
||||
|
||||
if pin == "" {
|
||||
// This cannot be authenticated so not worth wasting time sending the request.
|
||||
// This cannot be authenticated, so not worth wasting time sending the request.
|
||||
return "", pinFormId
|
||||
}
|
||||
|
||||
@@ -56,7 +54,7 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
return pinUserId, ""
|
||||
return res.GetSessionToken(), ""
|
||||
}
|
||||
|
||||
return "", pinFormId
|
||||
|
||||
Reference in New Issue
Block a user