mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Get groups from the JWT tokens if it feature enabled for account
This commit is contained in:
@@ -34,6 +34,8 @@ const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
GroupIssuedAPI = "api"
|
||||
GroupIssuedJWT = "jwt"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
@@ -139,6 +141,13 @@ type Settings struct {
|
||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||
// and add it to account groups.
|
||||
JWTGroupsEnabled bool
|
||||
|
||||
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||
JWTGroupsClaimName string
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings {
|
||||
return &Settings{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -612,6 +623,25 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
||||
return a.Peers[peerID]
|
||||
}
|
||||
|
||||
// AddJWTGroups to existed groups if they does not exists
|
||||
func (a *Account) AddJWTGroups(groups []string) error {
|
||||
existedGroups := make(map[string]*Group)
|
||||
for _, g := range a.Groups {
|
||||
existedGroups[g.Name] = g
|
||||
}
|
||||
|
||||
for _, name := range groups {
|
||||
if _, ok := existedGroups[name]; !ok {
|
||||
a.Groups[name] = &Group{
|
||||
ID: xid.New().String(),
|
||||
Name: name,
|
||||
Issued: GroupIssuedJWT,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||
@@ -1241,6 +1271,24 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
log.Errorf("JWT groups are enabled but no claim name is set")
|
||||
return account, user, nil
|
||||
}
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if groups, ok := claim.([]string); ok {
|
||||
if err := account.AddJWTGroups(groups); err != nil {
|
||||
log.Errorf("failed to add JWT groups: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -460,6 +461,62 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
initAccount := newAccountWithId("", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
_, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
AccountId: accountID,
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
DomainCategory: "test-category",
|
||||
Raw: jwt.MapClaims{"idp-groups": []string{"group1", "group2"}},
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
manager.Store.SaveAccount(initAccount)
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
manager.Store.SaveAccount(initAccount)
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
g1, ok := account.Groups["group1"]
|
||||
require.True(t, ok, "group1 should be added to the account")
|
||||
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
|
||||
|
||||
g2, ok := account.Groups["group2"]
|
||||
require.True(t, ok, "group2 should be added to the account")
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
|
||||
@@ -205,6 +205,10 @@ func restore(file string) (*FileStore, error) {
|
||||
group.Peers[i] = p.ID
|
||||
}
|
||||
}
|
||||
// TODO: delete this block after migration
|
||||
if group.Issued == "" {
|
||||
group.Issued = GroupIssuedAPI
|
||||
}
|
||||
}
|
||||
|
||||
// detect routes that have Peer.Key as a reference and replace it with ID.
|
||||
|
||||
@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
|
||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||
}
|
||||
|
||||
// TODO: outdated, delete this
|
||||
func TestRestorePolicies_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
@@ -296,6 +297,30 @@ func TestRestorePolicies_Migration(t *testing.T) {
|
||||
"failed to restore a FileStore file - missing Account Policies Sources")
|
||||
}
|
||||
|
||||
func TestRestoreGroups_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
require.Len(t, account.Groups, 1, "failed to restore a FileStore file - missing Account Groups")
|
||||
|
||||
var group *Group
|
||||
for _, g := range account.Groups {
|
||||
group = g
|
||||
}
|
||||
|
||||
require.Equal(t, group.Issued, GroupIssuedAPI, "default group should has API issued mark")
|
||||
}
|
||||
|
||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ type Group struct {
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
// Issued of the group
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string
|
||||
}
|
||||
@@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any {
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Peers: g.Peers[:],
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: g.Peers[:],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
||||
|
||||
@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims := AuthorizationClaims{
|
||||
Raw: claims,
|
||||
}
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
|
||||
@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
DomainCategory: "public",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Raw: jwt.MapClaims{
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
|
||||
Reference in New Issue
Block a user