Get groups from the JWT tokens if it feature enabled for account

This commit is contained in:
Givi Khojanashvili
2023-06-19 10:03:34 +04:00
parent 042f124702
commit e41072e2fc
8 changed files with 172 additions and 4 deletions

View File

@@ -34,6 +34,8 @@ const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
@@ -139,6 +141,13 @@ type Settings struct {
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
}
// Copy copies the Settings struct
@@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings {
return &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
}
}
@@ -612,6 +623,25 @@ func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID]
}
// AddJWTGroups to existed groups if they does not exists
func (a *Account) AddJWTGroups(groups []string) error {
existedGroups := make(map[string]*Group)
for _, g := range a.Groups {
existedGroups[g.Name] = g
}
for _, name := range groups {
if _, ok := existedGroups[name]; !ok {
a.Groups[name] = &Group{
ID: xid.New().String(),
Name: name,
Issued: GroupIssuedJWT,
}
}
}
return nil
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
@@ -1241,6 +1271,24 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
}
}
if account.Settings.JWTGroupsEnabled {
if account.Settings.JWTGroupsClaimName == "" {
log.Errorf("JWT groups are enabled but no claim name is set")
return account, user, nil
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if groups, ok := claim.([]string); ok {
if err := account.AddJWTGroups(groups); err != nil {
log.Errorf("failed to add JWT groups: %v", err)
}
} else {
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
}
} else {
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
}
}
return account, user, nil
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route"
@@ -460,6 +461,62 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
}
}
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId("", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
_, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain)
require.NoError(t, err, "create init user failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID,
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
Raw: jwt.MapClaims{"idp-groups": []string{"group1", "group2"}},
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
manager.Store.SaveAccount(initAccount)
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
manager.Store.SaveAccount(initAccount)
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
g1, ok := account.Groups["group1"]
require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
g2, ok := account.Groups["group2"]
require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match")
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
})
}
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")

View File

@@ -205,6 +205,10 @@ func restore(file string) (*FileStore, error) {
group.Peers[i] = p.ID
}
}
// TODO: delete this block after migration
if group.Issued == "" {
group.Issued = GroupIssuedAPI
}
}
// detect routes that have Peer.Key as a reference and replace it with ID.

View File

@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
}
// TODO: outdated, delete this
func TestRestorePolicies_Migration(t *testing.T) {
storeDir := t.TempDir()
@@ -296,6 +297,30 @@ func TestRestorePolicies_Migration(t *testing.T) {
"failed to restore a FileStore file - missing Account Policies Sources")
}
func TestRestoreGroups_Migration(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
require.Len(t, account.Groups, 1, "failed to restore a FileStore file - missing Account Groups")
var group *Group
for _, g := range account.Groups {
group = g
}
require.Equal(t, group.Issued, GroupIssuedAPI, "default group should has API issued mark")
}
func TestGetAccountByPrivateDomain(t *testing.T) {
storeDir := t.TempDir()

View File

@@ -14,6 +14,9 @@ type Group struct {
// Name visible in the UI
Name string
// Issued of the group
Issued string
// Peers list of the group
Peers []string
}
@@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any {
func (g *Group) Copy() *Group {
return &Group{
ID: g.ID,
Name: g.Name,
Peers: g.Peers[:],
ID: g.ID,
Name: g.Name,
Issued: g.Issued,
Peers: g.Peers[:],
}
}

View File

@@ -1,9 +1,15 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
)
// AuthorizationClaims stores authorization information from JWTs
type AuthorizationClaims struct {
UserId string
AccountId string
Domain string
DomainCategory string
Raw jwt.MapClaims
}

View File

@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims := AuthorizationClaims{
Raw: claims,
}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims

View File

@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
Domain: "test.com",
AccountId: "testAcc",
DomainCategory: "public",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Raw: jwt.MapClaims{
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",