mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
add stateless proxy sessions
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@@ -75,6 +76,14 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
|
|||||||
|
|
||||||
reverseProxy.Auth = authConfig
|
reverseProxy.Auth = authConfig
|
||||||
|
|
||||||
|
// Generate session JWT signing keys
|
||||||
|
keyPair, err := sessionkey.GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate session keys: %w", err)
|
||||||
|
}
|
||||||
|
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
|
||||||
|
reverseProxy.SessionPublicKey = keyPair.PublicKey
|
||||||
|
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
// Check for duplicate domain
|
// Check for duplicate domain
|
||||||
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
|
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util/crypt"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -58,15 +59,10 @@ type BearerAuthConfig struct {
|
|||||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LinkAuthConfig struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||||
LinkAuth *LinkAuthConfig `json:"link_auth,omitempty" gorm:"serializer:json"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCValidationConfig struct {
|
type OIDCValidationConfig struct {
|
||||||
@@ -83,14 +79,16 @@ type ReverseProxyMeta struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ReverseProxy struct {
|
type ReverseProxy struct {
|
||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
Name string
|
Name string
|
||||||
Domain string `gorm:"index"`
|
Domain string `gorm:"index"`
|
||||||
Targets []Target `gorm:"serializer:json"`
|
Targets []Target `gorm:"serializer:json"`
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Auth AuthConfig `gorm:"serializer:json"`
|
Auth AuthConfig `gorm:"serializer:json"`
|
||||||
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||||
|
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||||
|
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
|
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
|
||||||
@@ -132,12 +130,6 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Auth.LinkAuth != nil {
|
|
||||||
authConfig.LinkAuth = &api.LinkAuthConfig{
|
|
||||||
Enabled: r.Auth.LinkAuth.Enabled,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert internal targets to API targets
|
// Convert internal targets to API targets
|
||||||
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
|
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
|
||||||
for _, target := range r.Targets {
|
for _, target := range r.Targets {
|
||||||
@@ -199,7 +191,10 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := &proto.Authentication{}
|
auth := &proto.Authentication{
|
||||||
|
SessionKey: r.SessionPublicKey,
|
||||||
|
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
|
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
|
||||||
auth.Password = true
|
auth.Password = true
|
||||||
@@ -210,16 +205,7 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
|
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
|
||||||
auth.Oidc = &proto.OIDC{
|
auth.Oidc = true
|
||||||
Issuer: oidcConfig.Issuer,
|
|
||||||
Audiences: oidcConfig.Audiences,
|
|
||||||
KeysLocation: oidcConfig.KeysLocation,
|
|
||||||
MaxTokenAge: oidcConfig.MaxTokenAgeSeconds,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Auth.LinkAuth != nil && r.Auth.LinkAuth.Enabled {
|
|
||||||
auth.Link = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &proto.ProxyMapping{
|
return &proto.ProxyMapping{
|
||||||
@@ -291,13 +277,6 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
|
|||||||
}
|
}
|
||||||
r.Auth.BearerAuth = bearerAuth
|
r.Auth.BearerAuth = bearerAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Auth.LinkAuth != nil {
|
|
||||||
r.Auth.LinkAuth = &LinkAuthConfig{
|
|
||||||
Enabled: req.Auth.LinkAuth.Enabled,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ReverseProxy) Validate() error {
|
func (r *ReverseProxy) Validate() error {
|
||||||
@@ -322,3 +301,53 @@ func (r *ReverseProxy) Validate() error {
|
|||||||
func (r *ReverseProxy) EventMeta() map[string]any {
|
func (r *ReverseProxy) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": r.Name, "domain": r.Domain}
|
return map[string]any{"name": r.Name, "domain": r.Domain}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ReverseProxy) Copy() *ReverseProxy {
|
||||||
|
targets := make([]Target, len(r.Targets))
|
||||||
|
copy(targets, r.Targets)
|
||||||
|
|
||||||
|
return &ReverseProxy{
|
||||||
|
ID: r.ID,
|
||||||
|
AccountID: r.AccountID,
|
||||||
|
Name: r.Name,
|
||||||
|
Domain: r.Domain,
|
||||||
|
Targets: targets,
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
Auth: r.Auth,
|
||||||
|
Meta: r.Meta,
|
||||||
|
SessionPrivateKey: r.SessionPrivateKey,
|
||||||
|
SessionPublicKey: r.SessionPublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||||
|
if enc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.SessionPrivateKey != "" {
|
||||||
|
var err error
|
||||||
|
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package sessionkey
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
PrivateKey string
|
||||||
|
PublicKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Method auth.Method `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKeyPair() (*KeyPair, error) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KeyPair{
|
||||||
|
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
||||||
|
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||||
|
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
||||||
|
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey := ed25519.PrivateKey(privKeyBytes)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: auth.SessionJWTIssuer,
|
||||||
|
Subject: userID,
|
||||||
|
Audience: jwt.ClaimStrings{domain},
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
},
|
||||||
|
Method: method,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||||
|
signedToken, err := token.SignedString(privKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sign token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signedToken, nil
|
||||||
|
}
|
||||||
@@ -23,9 +23,11 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -317,7 +319,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
// TODO: log the error
|
// TODO: log the error
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
|
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var authenticated bool
|
var authenticated bool
|
||||||
|
var userId string
|
||||||
|
var method proxyauth.Method
|
||||||
switch v := req.GetRequest().(type) {
|
switch v := req.GetRequest().(type) {
|
||||||
case *proto.AuthenticateRequest_Pin:
|
case *proto.AuthenticateRequest_Pin:
|
||||||
auth := proxy.Auth.PinAuth
|
auth := proxy.Auth.PinAuth
|
||||||
@@ -327,6 +332,8 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
authenticated = subtle.ConstantTimeCompare([]byte(auth.Pin), []byte(v.Pin.GetPin())) == 1
|
authenticated = subtle.ConstantTimeCompare([]byte(auth.Pin), []byte(v.Pin.GetPin())) == 1
|
||||||
|
userId = "pin-user"
|
||||||
|
method = proxyauth.MethodPIN
|
||||||
case *proto.AuthenticateRequest_Password:
|
case *proto.AuthenticateRequest_Password:
|
||||||
auth := proxy.Auth.PasswordAuth
|
auth := proxy.Auth.PasswordAuth
|
||||||
if auth == nil || !auth.Enabled {
|
if auth == nil || !auth.Enabled {
|
||||||
@@ -335,9 +342,28 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
authenticated = subtle.ConstantTimeCompare([]byte(auth.Password), []byte(v.Password.GetPassword())) == 1
|
authenticated = subtle.ConstantTimeCompare([]byte(auth.Password), []byte(v.Password.GetPassword())) == 1
|
||||||
|
userId = "password-user"
|
||||||
|
method = proxyauth.MethodPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var token string
|
||||||
|
if authenticated && proxy.SessionPrivateKey != "" {
|
||||||
|
token, err = sessionkey.SignToken(
|
||||||
|
proxy.SessionPrivateKey,
|
||||||
|
userId,
|
||||||
|
proxy.Domain,
|
||||||
|
method,
|
||||||
|
proxyauth.DefaultSessionExpiry,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to sign session token")
|
||||||
|
authenticated = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.AuthenticateResponse{
|
return &proto.AuthenticateResponse{
|
||||||
Success: authenticated,
|
Success: authenticated,
|
||||||
|
SessionToken: token,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -516,3 +542,35 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
|||||||
|
|
||||||
return verifier, redirectURL, nil
|
return verifier, redirectURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||||
|
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||||
|
// Find the proxy by domain to get its signing key
|
||||||
|
proxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get reverse proxies: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxy *reverseproxy.ReverseProxy
|
||||||
|
for _, p := range proxies {
|
||||||
|
if p.Domain == domain {
|
||||||
|
proxy = p
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if proxy == nil {
|
||||||
|
return "", fmt.Errorf("reverse proxy not found for domain: %s", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if proxy.SessionPrivateKey == "" {
|
||||||
|
return "", fmt.Errorf("no session key configured for domain: %s", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionkey.SignToken(
|
||||||
|
proxy.SessionPrivateKey,
|
||||||
|
userID,
|
||||||
|
domain,
|
||||||
|
method,
|
||||||
|
proxyauth.DefaultSessionExpiry,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@@ -10,6 +11,7 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthCallbackHandler struct {
|
type AuthCallbackHandler struct {
|
||||||
@@ -65,16 +67,77 @@ func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Requ
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectQuery := redirectURL.Query()
|
// Extract user ID from the OIDC token
|
||||||
redirectQuery.Set("access_token", token.AccessToken)
|
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
|
||||||
if token.RefreshToken != "" {
|
if userID == "" {
|
||||||
redirectQuery.Set("refresh_token", token.RefreshToken)
|
log.Error("Failed to extract user ID from OIDC token")
|
||||||
|
http.Error(w, "Failed to validate token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate session JWT instead of passing OIDC access_token
|
||||||
|
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Failed to create session token")
|
||||||
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
redirectURL.RawQuery = redirectQuery.Encode()
|
|
||||||
|
|
||||||
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
|
// Redirect must be HTTPS, regardless of what was originally intended (which should always be HTTPS but better to double-check here).
|
||||||
redirectURL.Scheme = "https"
|
redirectURL.Scheme = "https"
|
||||||
|
|
||||||
log.WithField("redirect", redirectURL.String()).Debug("OAuth callback: redirecting user with token")
|
// Pass the session token in the URL query parameter. The proxy middleware will
|
||||||
|
// extract it, validate it, set its own cookie, and redirect to remove the token from the URL.
|
||||||
|
// We cannot set the cookie here because cookies are domain-scoped (RFC 6265) and the
|
||||||
|
// management server cannot set cookies for the proxy's domain.
|
||||||
|
query := redirectURL.Query()
|
||||||
|
query.Set("session_token", sessionToken)
|
||||||
|
redirectURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
|
||||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUserIDFromToken extracts the user ID from an OIDC token.
|
||||||
|
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
|
||||||
|
// Try to get ID token from the oauth2 token extras
|
||||||
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
log.Warn("No id_token in OIDC response")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: config.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to verify ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract claims
|
||||||
|
var claims struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
}
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
log.WithError(err).Warn("Failed to extract claims from ID token")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer subject, fall back to user_id or email
|
||||||
|
if claims.Subject != "" {
|
||||||
|
return claims.Subject
|
||||||
|
}
|
||||||
|
if claims.UserID != "" {
|
||||||
|
return claims.UserID
|
||||||
|
}
|
||||||
|
if claims.Email != "" {
|
||||||
|
return claims.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -4609,7 +4609,11 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
|
func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
|
||||||
result := s.db.Create(proxy)
|
proxyCopy := proxy.Copy()
|
||||||
|
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return fmt.Errorf("encrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
result := s.db.Create(proxyCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to create reverse proxy to store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to create reverse proxy to store: %v", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to create reverse proxy to store")
|
return status.Errorf(status.Internal, "failed to create reverse proxy to store")
|
||||||
@@ -4619,7 +4623,11 @@ func (s *SqlStore) CreateReverseProxy(ctx context.Context, proxy *reverseproxy.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
|
func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.ReverseProxy) error {
|
||||||
result := s.db.Select("*").Save(proxy)
|
proxyCopy := proxy.Copy()
|
||||||
|
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return fmt.Errorf("encrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
result := s.db.Select("*").Save(proxyCopy)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
|
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
|
||||||
@@ -4659,6 +4667,10 @@ func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength Locking
|
|||||||
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4674,6 +4686,10 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
|
|||||||
return nil, status.Errorf(status.Internal, "failed to get reverse proxy by domain from store")
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy by domain from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4690,6 +4706,12 @@ func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingSt
|
|||||||
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, proxy := range proxyList {
|
||||||
|
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return proxyList, nil
|
return proxyList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4706,6 +4728,12 @@ func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength Lo
|
|||||||
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
return nil, status.Errorf(status.Internal, "failed to get reverse proxy from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, proxy := range proxyList {
|
||||||
|
if err := proxy.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt reverse proxy data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return proxyList, nil
|
return proxyList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
60
proxy/auth/auth.go
Normal file
60
proxy/auth/auth.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Package auth contains exported proxy auth values.
|
||||||
|
// These are used to ensure coherent usage across management and proxy implementations.
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Method string
|
||||||
|
|
||||||
|
var (
|
||||||
|
MethodPassword Method = "password"
|
||||||
|
MethodPIN Method = "pin"
|
||||||
|
MethodOIDC Method = "oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m Method) String() string {
|
||||||
|
return string(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionCookieName = "nb_session"
|
||||||
|
DefaultSessionExpiry = 24 * time.Hour
|
||||||
|
SessionJWTIssuer = "netbird-management"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateSessionJWT validates a session JWT and returns the user ID and method.
|
||||||
|
func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) {
|
||||||
|
if publicKey == nil {
|
||||||
|
return "", "", fmt.Errorf("no public key configured for domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||||
|
}
|
||||||
|
return publicKey, nil
|
||||||
|
}, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("parse token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return "", "", fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, _ := claims.GetSubject()
|
||||||
|
if sub == "" {
|
||||||
|
return "", "", fmt.Errorf("missing subject claim")
|
||||||
|
}
|
||||||
|
|
||||||
|
methodClaim, _ := claims["method"].(string)
|
||||||
|
|
||||||
|
return sub, methodClaim, nil
|
||||||
|
}
|
||||||
@@ -2,6 +2,8 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestContextKey string
|
type requestContextKey string
|
||||||
@@ -11,13 +13,13 @@ const (
|
|||||||
authUserKey requestContextKey = "authUser"
|
authUserKey requestContextKey = "authUser"
|
||||||
)
|
)
|
||||||
|
|
||||||
func withAuthMethod(ctx context.Context, method Method) context.Context {
|
func withAuthMethod(ctx context.Context, method auth.Method) context.Context {
|
||||||
return context.WithValue(ctx, authMethodKey, method)
|
return context.WithValue(ctx, authMethodKey, method)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MethodFromContext(ctx context.Context) Method {
|
func MethodFromContext(ctx context.Context) auth.Method {
|
||||||
v := ctx.Value(authMethodKey)
|
v := ctx.Value(authMethodKey)
|
||||||
method, ok := v.(Method)
|
method, ok := v.(auth.Method)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
const linkFormId = "email"
|
|
||||||
|
|
||||||
type Link struct {
|
|
||||||
id, accountId string
|
|
||||||
client authenticator
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLink(client authenticator, id, accountId string) Link {
|
|
||||||
return Link{
|
|
||||||
id: id,
|
|
||||||
accountId: accountId,
|
|
||||||
client: client,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (Link) Type() Method {
|
|
||||||
return MethodLink
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l Link) Authenticate(r *http.Request) (string, string) {
|
|
||||||
email := r.FormValue(linkFormId)
|
|
||||||
|
|
||||||
res, err := l.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
|
||||||
Id: l.id,
|
|
||||||
AccountId: l.accountId,
|
|
||||||
Request: &proto.AuthenticateRequest_Link{
|
|
||||||
Link: &proto.LinkRequest{
|
|
||||||
Email: email,
|
|
||||||
Redirect: "", // TODO: calculate this.
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// TODO: log error here
|
|
||||||
return "", linkFormId
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.GetSuccess() {
|
|
||||||
// Use the email address as the user identifier.
|
|
||||||
return email, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", linkFormId
|
|
||||||
}
|
|
||||||
@@ -2,73 +2,50 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/ed25519"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/web"
|
"github.com/netbirdio/netbird/proxy/web"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Method string
|
|
||||||
|
|
||||||
var (
|
|
||||||
MethodPassword Method = "password"
|
|
||||||
MethodPIN Method = "pin"
|
|
||||||
MethodOIDC Method = "oidc"
|
|
||||||
MethodLink Method = "link"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m Method) String() string {
|
|
||||||
return string(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
sessionCookieName = "nb_session"
|
|
||||||
sessionExpiration = 24 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
type session struct {
|
|
||||||
UserID string
|
|
||||||
Method Method
|
|
||||||
CreatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type authenticator interface {
|
type authenticator interface {
|
||||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Scheme interface {
|
type Scheme interface {
|
||||||
Type() Method
|
Type() auth.Method
|
||||||
// Authenticate should check the passed request and determine whether
|
// Authenticate should check the passed request and determine whether
|
||||||
// it represents an authenticated user request. If it does not, then
|
// it represents an authenticated user request. If it does not, then
|
||||||
// an empty string should indicate an unauthenticated request which
|
// an empty string should indicate an unauthenticated request which
|
||||||
// will be rejected; optionally, it can also return any data that should
|
// will be rejected; optionally, it can also return any data that should
|
||||||
// be included in a UI template when prompting the user to authenticate.
|
// be included in a UI template when prompting the user to authenticate.
|
||||||
// If the request is authenticated, then a user id should be returned.
|
// If the request is authenticated, then a session token should be returned.
|
||||||
Authenticate(*http.Request) (userid string, promptData string)
|
Authenticate(*http.Request) (token string, promptData string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DomainConfig struct {
|
||||||
|
Schemes []Scheme
|
||||||
|
SessionPublicKey ed25519.PublicKey
|
||||||
|
SessionExpiration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
domainsMux sync.RWMutex
|
domainsMux sync.RWMutex
|
||||||
domains map[string][]Scheme
|
domains map[string]DomainConfig
|
||||||
sessionsMux sync.RWMutex
|
|
||||||
sessions map[string]*session
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMiddleware() *Middleware {
|
func NewMiddleware() *Middleware {
|
||||||
mw := &Middleware{
|
return &Middleware{
|
||||||
domains: make(map[string][]Scheme),
|
domains: make(map[string]DomainConfig),
|
||||||
sessions: make(map[string]*session),
|
|
||||||
}
|
}
|
||||||
// TODO: goroutine is leaked here.
|
|
||||||
go mw.cleanupSessions()
|
|
||||||
return mw
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Protect applies authentication middleware to the passed handler.
|
// Protect applies authentication middleware to the passed handler.
|
||||||
@@ -87,24 +64,20 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
|||||||
host = r.Host
|
host = r.Host
|
||||||
}
|
}
|
||||||
mw.domainsMux.RLock()
|
mw.domainsMux.RLock()
|
||||||
schemes, exists := mw.domains[host]
|
config, exists := mw.domains[host]
|
||||||
mw.domainsMux.RUnlock()
|
mw.domainsMux.RUnlock()
|
||||||
|
|
||||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||||
if !exists || len(schemes) == 0 {
|
if !exists || len(config.Schemes) == 0 {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for an existing session to avoid users having to authenticate for every request.
|
// Check for an existing session cookie (contains JWT)
|
||||||
// TODO: This does not work if you are load balancing across multiple proxy servers.
|
if cookie, err := r.Cookie(auth.SessionCookieName); err == nil {
|
||||||
if cookie, err := r.Cookie(sessionCookieName); err == nil {
|
if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil {
|
||||||
mw.sessionsMux.RLock()
|
ctx := withAuthMethod(r.Context(), auth.Method(method))
|
||||||
sess, ok := mw.sessions[cookie.Value]
|
ctx = withAuthUser(ctx, userID)
|
||||||
mw.sessionsMux.RUnlock()
|
|
||||||
if ok {
|
|
||||||
ctx := withAuthMethod(r.Context(), sess.Method)
|
|
||||||
ctx = withAuthUser(ctx, sess.UserID)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -112,28 +85,59 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
// Try to authenticate with each scheme.
|
// Try to authenticate with each scheme.
|
||||||
methods := make(map[string]string)
|
methods := make(map[string]string)
|
||||||
for _, s := range schemes {
|
for _, scheme := range config.Schemes {
|
||||||
userid, promptData := s.Authenticate(r)
|
token, promptData := scheme.Authenticate(r)
|
||||||
if userid != "" {
|
if token != "" {
|
||||||
mw.createSession(w, r, userid, s.Type())
|
userid, _, err := auth.ValidateSessionJWT(token, host, config.SessionPublicKey)
|
||||||
// Clean the path and redirect to the naked URL.
|
if err != nil {
|
||||||
// This is intended to prevent leaking potentially
|
// TODO: log, this should never fail.
|
||||||
// sensitive query parameters for authentication
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
// methods.
|
return
|
||||||
http.Redirect(w, r, r.URL.Path, http.StatusFound)
|
}
|
||||||
|
|
||||||
|
expiration := config.SessionExpiration
|
||||||
|
if expiration == 0 {
|
||||||
|
expiration = auth.DefaultSessionExpiry
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: auth.SessionCookieName,
|
||||||
|
Value: token,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: int(expiration.Seconds()),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := withAuthMethod(r.Context(), scheme.Type())
|
||||||
|
ctx = withAuthUser(ctx, userid)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
methods[s.Type().String()] = promptData
|
methods[scheme.Type().String()] = promptData
|
||||||
}
|
}
|
||||||
|
|
||||||
web.ServeHTTP(w, r, map[string]any{"methods": methods})
|
web.ServeHTTP(w, r, map[string]any{"methods": methods})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme) {
|
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) {
|
||||||
|
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(pubKeyBytes) != ed25519.PublicKeySize {
|
||||||
|
// TODO: log
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
mw.domainsMux.Lock()
|
mw.domainsMux.Lock()
|
||||||
defer mw.domainsMux.Unlock()
|
defer mw.domainsMux.Unlock()
|
||||||
mw.domains[domain] = schemes
|
mw.domains[domain] = DomainConfig{
|
||||||
|
Schemes: schemes,
|
||||||
|
SessionPublicKey: pubKeyBytes,
|
||||||
|
SessionExpiration: expiration,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *Middleware) RemoveDomain(domain string) {
|
func (mw *Middleware) RemoveDomain(domain string) {
|
||||||
@@ -141,39 +145,3 @@ func (mw *Middleware) RemoveDomain(domain string) {
|
|||||||
defer mw.domainsMux.Unlock()
|
defer mw.domainsMux.Unlock()
|
||||||
delete(mw.domains, domain)
|
delete(mw.domains, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *Middleware) createSession(w http.ResponseWriter, r *http.Request, userID string, method Method) {
|
|
||||||
// Generate a random sessionID
|
|
||||||
b := make([]byte, 32)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
sessionID := base64.URLEncoding.EncodeToString(b)
|
|
||||||
|
|
||||||
mw.sessionsMux.Lock()
|
|
||||||
mw.sessions[sessionID] = &session{
|
|
||||||
UserID: userID,
|
|
||||||
Method: method,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
mw.sessionsMux.Unlock()
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: sessionCookieName,
|
|
||||||
Value: sessionID,
|
|
||||||
HttpOnly: true, // This cookie is only for proxy access, so no scripts should touch it.
|
|
||||||
Secure: true, // The proxy only accepts TLS traffic regardless of the service proxied behind.
|
|
||||||
SameSite: http.SameSiteLaxMode, // TODO: might this actually be strict mode?
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mw *Middleware) cleanupSessions() {
|
|
||||||
for range time.Tick(time.Minute) {
|
|
||||||
cutoff := time.Now().Add(-sessionExpiration)
|
|
||||||
mw.sessionsMux.Lock()
|
|
||||||
for id, sess := range mw.sessions {
|
|
||||||
if sess.CreatedAt.Before(cutoff) {
|
|
||||||
delete(mw.sessions, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mw.sessionsMux.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,66 +4,41 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
|
||||||
|
|
||||||
gojwt "github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type urlGenerator interface {
|
type urlGenerator interface {
|
||||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCConfig holds configuration for OIDC JWT verification
|
|
||||||
type OIDCConfig struct {
|
|
||||||
Issuer string
|
|
||||||
Audiences []string
|
|
||||||
KeysLocation string
|
|
||||||
MaxTokenAgeSeconds int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// oidcState stores CSRF state with expiration
|
|
||||||
type oidcState struct {
|
|
||||||
OriginalURL string
|
|
||||||
CreatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// OIDC implements the Scheme interface for JWT/OIDC authentication
|
|
||||||
type OIDC struct {
|
type OIDC struct {
|
||||||
id, accountId string
|
id, accountId string
|
||||||
validator *jwt.Validator
|
client urlGenerator
|
||||||
maxTokenAgeSeconds int64
|
|
||||||
client urlGenerator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOIDC creates a new OIDC authentication scheme
|
// NewOIDC creates a new OIDC authentication scheme
|
||||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
func NewOIDC(client urlGenerator, id, accountId string) OIDC {
|
||||||
return &OIDC{
|
return OIDC{
|
||||||
id: id,
|
id: id,
|
||||||
accountId: accountId,
|
accountId: accountId,
|
||||||
validator: jwt.NewValidator(
|
client: client,
|
||||||
cfg.Issuer,
|
|
||||||
cfg.Audiences,
|
|
||||||
cfg.KeysLocation,
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
|
||||||
client: client,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*OIDC) Type() Method {
|
func (OIDC) Type() auth.Method {
|
||||||
return MethodOIDC
|
return auth.MethodOIDC
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
func (o OIDC) Authenticate(r *http.Request) (string, string) {
|
||||||
if token := r.URL.Query().Get("access_token"); token != "" {
|
// Check for the session_token query param (from OIDC redirects).
|
||||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
// The management server passes the token in the URL because it cannot set
|
||||||
return userID, ""
|
// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
|
||||||
}
|
if token := r.URL.Query().Get("session_token"); token != "" {
|
||||||
|
return token, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := &url.URL{
|
redirectURL := &url.URL{
|
||||||
@@ -84,55 +59,3 @@ func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
|||||||
|
|
||||||
return "", res.GetUrl()
|
return "", res.GetUrl()
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateToken validates a JWT ID token and returns the user ID (subject)
|
|
||||||
// Returns empty string if token is invalid.
|
|
||||||
func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
|
||||||
if o.validator == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
idToken, err := o.validator.ValidateAndParse(ctx, token)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: log or return?
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
iat, err := idToken.Claims.GetIssuedAt()
|
|
||||||
if err != nil {
|
|
||||||
// TODO: log or return?
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// If max token age is 0 skip this check.
|
|
||||||
if o.maxTokenAgeSeconds > 0 && time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
|
||||||
// TODO: log or return?
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return extractUserID(idToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractUserID(token *gojwt.Token) string {
|
|
||||||
if token == nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
return getUserIDFromClaims(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
|
||||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
|
||||||
return userID
|
|
||||||
}
|
|
||||||
if email, ok := claims["email"].(string); ok && email != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const passwordFormId = "password"
|
||||||
passwordUserId = "password-user"
|
|
||||||
passwordFormId = "password"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Password struct {
|
type Password struct {
|
||||||
id, accountId string
|
id, accountId string
|
||||||
@@ -24,8 +22,8 @@ func NewPassword(client authenticator, id, accountId string) Password {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Password) Type() Method {
|
func (Password) Type() auth.Method {
|
||||||
return MethodPassword
|
return auth.MethodPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate attempts to authenticate the request using a form
|
// Authenticate attempts to authenticate the request using a form
|
||||||
@@ -37,7 +35,7 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
|
|||||||
password := r.FormValue(passwordFormId)
|
password := r.FormValue(passwordFormId)
|
||||||
|
|
||||||
if password == "" {
|
if password == "" {
|
||||||
// This cannot be authenticated so not worth wasting time sending the request.
|
// This cannot be authenticated, so not worth wasting time sending the request.
|
||||||
return "", passwordFormId
|
return "", passwordFormId
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +54,7 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.GetSuccess() {
|
if res.GetSuccess() {
|
||||||
return passwordUserId, ""
|
return res.GetSessionToken(), ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", passwordFormId
|
return "", passwordFormId
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const pinFormId = "pin"
|
||||||
pinUserId = "pin-user"
|
|
||||||
pinFormId = "pin"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Pin struct {
|
type Pin struct {
|
||||||
id, accountId string
|
id, accountId string
|
||||||
@@ -24,8 +22,8 @@ func NewPin(client authenticator, id, accountId string) Pin {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (Pin) Type() Method {
|
func (Pin) Type() auth.Method {
|
||||||
return MethodPIN
|
return auth.MethodPIN
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate attempts to authenticate the request using a form
|
// Authenticate attempts to authenticate the request using a form
|
||||||
@@ -37,7 +35,7 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
|
|||||||
pin := r.FormValue(pinFormId)
|
pin := r.FormValue(pinFormId)
|
||||||
|
|
||||||
if pin == "" {
|
if pin == "" {
|
||||||
// This cannot be authenticated so not worth wasting time sending the request.
|
// This cannot be authenticated, so not worth wasting time sending the request.
|
||||||
return "", pinFormId
|
return "", pinFormId
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +54,7 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.GetSuccess() {
|
if res.GetSuccess() {
|
||||||
return pinUserId, ""
|
return res.GetSessionToken(), ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", pinFormId
|
return "", pinFormId
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
|||||||
}
|
}
|
||||||
s.Logger.Debug("Processing mapping update completed")
|
s.Logger.Debug("Processing mapping update completed")
|
||||||
|
|
||||||
// After the first mapping sync, mark initial sync complete.
|
// After the first mapping sync, mark the initial sync complete.
|
||||||
// Client health is checked directly in the startup probe.
|
// Client health is checked directly in the startup probe.
|
||||||
if !*initialSyncDone && s.healthChecker != nil {
|
if !*initialSyncDone && s.healthChecker != nil {
|
||||||
s.healthChecker.SetInitialSyncComplete()
|
s.healthChecker.SetInitialSyncComplete()
|
||||||
@@ -429,19 +429,12 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
|||||||
if mapping.GetAuth().GetPin() {
|
if mapping.GetAuth().GetPin() {
|
||||||
schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
schemes = append(schemes, auth.NewPin(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||||
}
|
}
|
||||||
if mapping.GetAuth().GetOidc() != nil {
|
if mapping.GetAuth().GetOidc() {
|
||||||
oidc := mapping.GetAuth().GetOidc()
|
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
||||||
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), auth.OIDCConfig{
|
|
||||||
Issuer: oidc.GetIssuer(),
|
|
||||||
Audiences: oidc.GetAudiences(),
|
|
||||||
KeysLocation: oidc.GetKeysLocation(),
|
|
||||||
MaxTokenAgeSeconds: oidc.GetMaxTokenAge(),
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
if mapping.GetAuth().GetLink() {
|
|
||||||
schemes = append(schemes, auth.NewLink(s.mgmtClient, mapping.GetId(), mapping.GetAccountId()))
|
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||||
}
|
s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge)
|
||||||
s.auth.AddDomain(mapping.GetDomain(), schemes)
|
|
||||||
s.proxy.AddMapping(s.protoToMapping(mapping))
|
s.proxy.AddMapping(s.protoToMapping(mapping))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -47,17 +47,11 @@ message PathMapping {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message Authentication {
|
message Authentication {
|
||||||
bool password = 1;
|
string session_key = 1;
|
||||||
bool pin = 2;
|
int64 max_session_age_seconds = 2;
|
||||||
optional OIDC oidc = 3;
|
bool password = 3;
|
||||||
bool link = 4;
|
bool pin = 4;
|
||||||
}
|
bool oidc = 5;
|
||||||
|
|
||||||
message OIDC {
|
|
||||||
string issuer = 1;
|
|
||||||
repeated string audiences = 2;
|
|
||||||
string keys_location = 3;
|
|
||||||
int64 max_token_age = 4;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message ProxyMapping {
|
message ProxyMapping {
|
||||||
@@ -100,7 +94,6 @@ message AuthenticateRequest {
|
|||||||
oneof request {
|
oneof request {
|
||||||
PasswordRequest password = 3;
|
PasswordRequest password = 3;
|
||||||
PinRequest pin = 4;
|
PinRequest pin = 4;
|
||||||
LinkRequest link = 5;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,13 +105,9 @@ message PinRequest {
|
|||||||
string pin = 1;
|
string pin = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LinkRequest {
|
|
||||||
string email = 1;
|
|
||||||
string redirect = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message AuthenticateResponse {
|
message AuthenticateResponse {
|
||||||
bool success = 1;
|
bool success = 1;
|
||||||
|
string session_token = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum ProxyStatus {
|
enum ProxyStatus {
|
||||||
|
|||||||
Reference in New Issue
Block a user