mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
add stateless proxy sessions
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -75,6 +76,14 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID
|
||||
|
||||
reverseProxy.Auth = authConfig
|
||||
|
||||
// Generate session JWT signing keys
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
reverseProxy.SessionPrivateKey = keyPair.PrivateKey
|
||||
reverseProxy.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Check for duplicate domain
|
||||
existingReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -58,15 +59,10 @@ type BearerAuthConfig struct {
|
||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
type LinkAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||
LinkAuth *LinkAuthConfig `json:"link_auth,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
type OIDCValidationConfig struct {
|
||||
@@ -83,14 +79,16 @@ type ReverseProxyMeta struct {
|
||||
}
|
||||
|
||||
type ReverseProxy struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
Targets []Target `gorm:"serializer:json"`
|
||||
Enabled bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
Targets []Target `gorm:"serializer:json"`
|
||||
Enabled bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
}
|
||||
|
||||
func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy {
|
||||
@@ -132,12 +130,6 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
|
||||
}
|
||||
}
|
||||
|
||||
if r.Auth.LinkAuth != nil {
|
||||
authConfig.LinkAuth = &api.LinkAuthConfig{
|
||||
Enabled: r.Auth.LinkAuth.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ReverseProxyTarget, 0, len(r.Targets))
|
||||
for _, target := range r.Targets {
|
||||
@@ -199,7 +191,10 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
|
||||
})
|
||||
}
|
||||
|
||||
auth := &proto.Authentication{}
|
||||
auth := &proto.Authentication{
|
||||
SessionKey: r.SessionPublicKey,
|
||||
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||
}
|
||||
|
||||
if r.Auth.PasswordAuth != nil && r.Auth.PasswordAuth.Enabled {
|
||||
auth.Password = true
|
||||
@@ -210,16 +205,7 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string, oidc
|
||||
}
|
||||
|
||||
if r.Auth.BearerAuth != nil && r.Auth.BearerAuth.Enabled {
|
||||
auth.Oidc = &proto.OIDC{
|
||||
Issuer: oidcConfig.Issuer,
|
||||
Audiences: oidcConfig.Audiences,
|
||||
KeysLocation: oidcConfig.KeysLocation,
|
||||
MaxTokenAge: oidcConfig.MaxTokenAgeSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
if r.Auth.LinkAuth != nil && r.Auth.LinkAuth.Enabled {
|
||||
auth.Link = true
|
||||
auth.Oidc = true
|
||||
}
|
||||
|
||||
return &proto.ProxyMapping{
|
||||
@@ -291,13 +277,6 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
|
||||
}
|
||||
r.Auth.BearerAuth = bearerAuth
|
||||
}
|
||||
|
||||
if req.Auth.LinkAuth != nil {
|
||||
r.Auth.LinkAuth = &LinkAuthConfig{
|
||||
Enabled: req.Auth.LinkAuth.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (r *ReverseProxy) Validate() error {
|
||||
@@ -322,3 +301,53 @@ func (r *ReverseProxy) Validate() error {
|
||||
func (r *ReverseProxy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": r.Name, "domain": r.Domain}
|
||||
}
|
||||
|
||||
func (r *ReverseProxy) Copy() *ReverseProxy {
|
||||
targets := make([]Target, len(r.Targets))
|
||||
copy(targets, r.Targets)
|
||||
|
||||
return &ReverseProxy{
|
||||
ID: r.ID,
|
||||
AccountID: r.AccountID,
|
||||
Name: r.Name,
|
||||
Domain: r.Domain,
|
||||
Targets: targets,
|
||||
Enabled: r.Enabled,
|
||||
Auth: r.Auth,
|
||||
Meta: r.Meta,
|
||||
SessionPrivateKey: r.SessionPrivateKey,
|
||||
SessionPublicKey: r.SessionPublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReverseProxy) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.SessionPrivateKey != "" {
|
||||
var err error
|
||||
r.SessionPrivateKey, err = enc.Encrypt(r.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ReverseProxy) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.SessionPrivateKey != "" {
|
||||
var err error
|
||||
r.SessionPrivateKey, err = enc.Decrypt(r.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
package sessionkey
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
PrivateKey string
|
||||
PublicKey string
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
Method auth.Method `json:"method"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (*KeyPair, error) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||
}
|
||||
|
||||
return &KeyPair{
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(priv),
|
||||
PublicKey: base64.StdEncoding.EncodeToString(pub),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
|
||||
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode private key: %w", err)
|
||||
}
|
||||
|
||||
if len(privKeyBytes) != ed25519.PrivateKeySize {
|
||||
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
|
||||
}
|
||||
|
||||
privKey := ed25519.PrivateKey(privKeyBytes)
|
||||
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: auth.SessionJWTIssuer,
|
||||
Subject: userID,
|
||||
Audience: jwt.ClaimStrings{domain},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
Method: method,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
|
||||
signedToken, err := token.SignedString(privKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign token: %w", err)
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
Reference in New Issue
Block a user