mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)
Embed Dex as a built-in IdP to simplify self-hosting setup. Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management. more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
This commit is contained in:
@@ -78,16 +78,18 @@ func parseTime(timeString string) time.Time {
|
||||
return parsedTime
|
||||
}
|
||||
|
||||
func (c ClaimsExtractor) audienceClaim(claimName string) string {
|
||||
url, err := url.JoinPath(c.authAudience, claimName)
|
||||
func (c *ClaimsExtractor) audienceClaim(claimName string) string {
|
||||
audienceURL, err := url.JoinPath(c.authAudience, claimName)
|
||||
if err != nil {
|
||||
return c.authAudience + claimName // as it was previously
|
||||
}
|
||||
|
||||
return url
|
||||
return audienceURL
|
||||
}
|
||||
|
||||
// ToUserAuth extracts user authentication information from a JWT token
|
||||
// ToUserAuth extracts user authentication information from a JWT token.
|
||||
// The token should contain standard claims like email, name, preferred_username.
|
||||
// When using Dex, make sure to set getUserInfo: true to have these claims populated.
|
||||
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userAuth := auth.UserAuth{}
|
||||
@@ -120,6 +122,21 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Extract email from standard "email" claim
|
||||
if email, ok := claims["email"].(string); ok {
|
||||
userAuth.Email = email
|
||||
}
|
||||
|
||||
// Extract name from standard "name" claim
|
||||
if name, ok := claims["name"].(string); ok {
|
||||
userAuth.Name = name
|
||||
}
|
||||
|
||||
// Extract name from standard "preferred_username" claim
|
||||
if preferredName, ok := claims["preferred_username"].(string); ok {
|
||||
userAuth.PreferredName = preferredName
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
|
||||
322
shared/auth/jwt/extractor_test.go
Normal file
322
shared/auth/jwt/extractor_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClaimsExtractor_ToUserAuth_ExtractsEmailAndName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims jwt.MapClaims
|
||||
userIDClaim string
|
||||
audience string
|
||||
expectedUserID string
|
||||
expectedEmail string
|
||||
expectedName string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "extracts email and name from standard claims",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "user-123",
|
||||
expectedEmail: "test@example.com",
|
||||
expectedName: "Test User",
|
||||
},
|
||||
{
|
||||
name: "extracts Dex encoded user ID",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs",
|
||||
"email": "dex-user@example.com",
|
||||
"name": "Dex User",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs",
|
||||
expectedEmail: "dex-user@example.com",
|
||||
expectedName: "Dex User",
|
||||
},
|
||||
{
|
||||
name: "handles missing email claim",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-456",
|
||||
"name": "User Without Email",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "user-456",
|
||||
expectedEmail: "",
|
||||
expectedName: "User Without Email",
|
||||
},
|
||||
{
|
||||
name: "handles missing name claim",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-789",
|
||||
"email": "noname@example.com",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "user-789",
|
||||
expectedEmail: "noname@example.com",
|
||||
expectedName: "",
|
||||
},
|
||||
{
|
||||
name: "handles missing both email and name",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-minimal",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "user-minimal",
|
||||
expectedEmail: "",
|
||||
expectedName: "",
|
||||
},
|
||||
{
|
||||
name: "extracts preferred_username",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-pref",
|
||||
"email": "pref@example.com",
|
||||
"name": "Preferred User",
|
||||
"preferred_username": "prefuser",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectedUserID: "user-pref",
|
||||
expectedEmail: "pref@example.com",
|
||||
expectedName: "Preferred User",
|
||||
},
|
||||
{
|
||||
name: "fails when user ID claim is empty",
|
||||
claims: jwt.MapClaims{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "uses custom user ID claim",
|
||||
claims: jwt.MapClaims{
|
||||
"user_id": "custom-user-id",
|
||||
"email": "custom@example.com",
|
||||
"name": "Custom User",
|
||||
},
|
||||
userIDClaim: "user_id",
|
||||
expectedUserID: "custom-user-id",
|
||||
expectedEmail: "custom@example.com",
|
||||
expectedName: "Custom User",
|
||||
},
|
||||
{
|
||||
name: "extracts account ID with audience prefix",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-with-account",
|
||||
"email": "account@example.com",
|
||||
"name": "Account User",
|
||||
"https://api.netbird.io/wt_account_id": "account-123",
|
||||
"https://api.netbird.io/wt_account_domain": "example.com",
|
||||
},
|
||||
userIDClaim: "sub",
|
||||
audience: "https://api.netbird.io",
|
||||
expectedUserID: "user-with-account",
|
||||
expectedEmail: "account@example.com",
|
||||
expectedName: "Account User",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create extractor with options
|
||||
opts := []ClaimsExtractorOption{}
|
||||
if tt.userIDClaim != "" {
|
||||
opts = append(opts, WithUserIDClaim(tt.userIDClaim))
|
||||
}
|
||||
if tt.audience != "" {
|
||||
opts = append(opts, WithAudience(tt.audience))
|
||||
}
|
||||
extractor := NewClaimsExtractor(opts...)
|
||||
|
||||
// Create a mock token with the claims
|
||||
token := &jwt.Token{
|
||||
Claims: tt.claims,
|
||||
}
|
||||
|
||||
// Extract user auth
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedUserID, userAuth.UserId)
|
||||
assert.Equal(t, tt.expectedEmail, userAuth.Email)
|
||||
assert.Equal(t, tt.expectedName, userAuth.Name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_ToUserAuth_PreferredUsername(t *testing.T) {
|
||||
extractor := NewClaimsExtractor(WithUserIDClaim("sub"))
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
"preferred_username": "testuser",
|
||||
}
|
||||
|
||||
token := &jwt.Token{Claims: claims}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "user-123", userAuth.UserId)
|
||||
assert.Equal(t, "test@example.com", userAuth.Email)
|
||||
assert.Equal(t, "Test User", userAuth.Name)
|
||||
assert.Equal(t, "testuser", userAuth.PreferredName)
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_ToUserAuth_LastLogin(t *testing.T) {
|
||||
extractor := NewClaimsExtractor(
|
||||
WithUserIDClaim("sub"),
|
||||
WithAudience("https://api.netbird.io"),
|
||||
)
|
||||
|
||||
expectedTime := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"email": "test@example.com",
|
||||
"https://api.netbird.io/nb_last_login": expectedTime.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
token := &jwt.Token{Claims: claims}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expectedTime, userAuth.LastLogin)
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_ToUserAuth_Invited(t *testing.T) {
|
||||
extractor := NewClaimsExtractor(
|
||||
WithUserIDClaim("sub"),
|
||||
WithAudience("https://api.netbird.io"),
|
||||
)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"email": "invited@example.com",
|
||||
"https://api.netbird.io/nb_invited": true,
|
||||
}
|
||||
|
||||
token := &jwt.Token{Claims: claims}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, userAuth.Invited)
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_ToGroups(t *testing.T) {
|
||||
extractor := NewClaimsExtractor(WithUserIDClaim("sub"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims jwt.MapClaims
|
||||
groupClaimName string
|
||||
expectedGroups []string
|
||||
}{
|
||||
{
|
||||
name: "extracts groups from claim",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"groups": []interface{}{"admin", "users", "developers"},
|
||||
},
|
||||
groupClaimName: "groups",
|
||||
expectedGroups: []string{"admin", "users", "developers"},
|
||||
},
|
||||
{
|
||||
name: "returns empty slice when claim missing",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
},
|
||||
groupClaimName: "groups",
|
||||
expectedGroups: []string{},
|
||||
},
|
||||
{
|
||||
name: "handles custom claim name",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"user_roles": []interface{}{"role1", "role2"},
|
||||
},
|
||||
groupClaimName: "user_roles",
|
||||
expectedGroups: []string{"role1", "role2"},
|
||||
},
|
||||
{
|
||||
name: "filters non-string values",
|
||||
claims: jwt.MapClaims{
|
||||
"sub": "user-123",
|
||||
"groups": []interface{}{"admin", 123, "users", true},
|
||||
},
|
||||
groupClaimName: "groups",
|
||||
expectedGroups: []string{"admin", "users"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &jwt.Token{Claims: tt.claims}
|
||||
groups := extractor.ToGroups(token, tt.groupClaimName)
|
||||
assert.Equal(t, tt.expectedGroups, groups)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_DefaultUserIDClaim(t *testing.T) {
|
||||
// When no user ID claim is specified, it should default to "sub"
|
||||
extractor := NewClaimsExtractor()
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": "default-user-id",
|
||||
"email": "default@example.com",
|
||||
}
|
||||
|
||||
token := &jwt.Token{Claims: claims}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "default-user-id", userAuth.UserId)
|
||||
}
|
||||
|
||||
func TestClaimsExtractor_DexUserIDFormat(t *testing.T) {
|
||||
// Test that the extractor correctly handles Dex's encoded user ID format
|
||||
// Dex encodes user IDs as base64(protobuf{user_id, connector_id})
|
||||
extractor := NewClaimsExtractor(WithUserIDClaim("sub"))
|
||||
|
||||
// This is an actual Dex-encoded user ID
|
||||
dexEncodedID := "CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs"
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": dexEncodedID,
|
||||
"email": "dex@example.com",
|
||||
"name": "Dex User",
|
||||
}
|
||||
|
||||
token := &jwt.Token{Claims: claims}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The extractor should pass through the encoded ID as-is
|
||||
// Decoding is done elsewhere (e.g., in the Dex provider)
|
||||
assert.Equal(t, dexEncodedID, userAuth.UserId)
|
||||
assert.Equal(t, "dex@example.com", userAuth.Email)
|
||||
assert.Equal(t, "Dex User", userAuth.Name)
|
||||
}
|
||||
@@ -60,6 +60,7 @@ type Validator struct {
|
||||
keysLocation string
|
||||
idpSignkeyRefreshEnabled bool
|
||||
keys *Jwks
|
||||
lastForcedRefresh time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -84,26 +85,17 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
|
||||
}
|
||||
}
|
||||
|
||||
// forcedRefreshCooldown is the minimum time between forced key refreshes
|
||||
// to prevent abuse from invalid tokens with fake kid values
|
||||
const forcedRefreshCooldown = 30 * time.Second
|
||||
|
||||
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// If keys are rotated, verify the keys prior to token validation
|
||||
if v.idpSignkeyRefreshEnabled {
|
||||
// If the keys are invalid, retrieve new ones
|
||||
// @todo propose a separate go routine to regularly check these to prevent blocking when actually
|
||||
// validating the token
|
||||
if !v.keys.stillValid() {
|
||||
v.lock.Lock()
|
||||
defer v.lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
refreshedKeys = v.keys
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
|
||||
v.keys = refreshedKeys
|
||||
v.refreshKeys(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +104,18 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// If key not found and refresh is enabled, try refreshing keys and retry once.
|
||||
// This handles the case where keys were rotated but cache hasn't expired yet.
|
||||
// Use a cooldown to prevent abuse from tokens with fake kid values.
|
||||
if errors.Is(err, errKeyNotFound) && v.idpSignkeyRefreshEnabled {
|
||||
if v.forceRefreshKeys(ctx) {
|
||||
publicKey, err = getPublicKey(token, v.keys)
|
||||
if err == nil {
|
||||
return publicKey, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("getPublicKey error: %s", err)
|
||||
if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled {
|
||||
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
|
||||
@@ -123,6 +127,46 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Validator) refreshKeys(ctx context.Context) {
|
||||
v.lock.Lock()
|
||||
defer v.lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
v.keys = refreshedKeys
|
||||
}
|
||||
|
||||
// forceRefreshKeys refreshes keys if the cooldown period has passed.
|
||||
// Returns true if keys were refreshed, false if cooldown prevented refresh.
|
||||
// The cooldown check is done inside the lock to prevent race conditions.
|
||||
func (v *Validator) forceRefreshKeys(ctx context.Context) bool {
|
||||
v.lock.Lock()
|
||||
defer v.lock.Unlock()
|
||||
|
||||
// Check cooldown inside lock to prevent multiple goroutines from refreshing
|
||||
if time.Since(v.lastForcedRefresh) <= forcedRefreshCooldown {
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
v.keys = refreshedKeys
|
||||
v.lastForcedRefresh = time.Now()
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
@@ -165,12 +209,12 @@ func (jwks *Jwks) stillValid() bool {
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
jwks := &Jwks{}
|
||||
|
||||
url, err := url.ParseRequestURI(keysLocation)
|
||||
requestURI, err := url.ParseRequestURI(keysLocation)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
resp, err := http.Get(url.String())
|
||||
resp, err := http.Get(requestURI.String())
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
@@ -18,6 +18,15 @@ type UserAuth struct {
|
||||
|
||||
// The user id
|
||||
UserId string
|
||||
// The user's email address
|
||||
// (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field)
|
||||
Email string
|
||||
// The user's name
|
||||
// (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field)
|
||||
Name string
|
||||
// The user's preferred name
|
||||
// (optional, may be empty if not in token, make sure to set getUserInfo: true in Dex to have this field)
|
||||
PreferredName string
|
||||
// Last login time for this user
|
||||
LastLogin time.Time
|
||||
// The Groups the user belongs to on this account
|
||||
|
||||
Reference in New Issue
Block a user