diff --git a/management/server/account.go b/management/server/account.go index 25ee756ad..e48fd9ee6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -34,6 +34,8 @@ const ( PublicCategory = "public" PrivateCategory = "private" UnknownCategory = "unknown" + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour @@ -139,6 +141,13 @@ type Settings struct { // PeerLoginExpiration is a setting that indicates when peer login expires. // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. + JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string } // Copy copies the Settings struct @@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings { return &Settings{ PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, } } @@ -612,6 +623,25 @@ func (a *Account) GetPeer(peerID string) *Peer { return a.Peers[peerID] } +// AddJWTGroups to existed groups if they does not exists +func (a *Account) AddJWTGroups(groups []string) error { + existedGroups := make(map[string]*Group) + for _, g := range a.Groups { + existedGroups[g.Name] = g + } + + for _, name := range groups { + if _, ok := existedGroups[name]; !ok { + a.Groups[name] = &Group{ + ID: xid.New().String(), + Name: name, + Issued: GroupIssuedJWT, + } + } + } + return nil +} + // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, @@ -1241,6 +1271,24 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } } + if account.Settings.JWTGroupsEnabled { + if account.Settings.JWTGroupsClaimName == "" { + log.Errorf("JWT groups are enabled but no claim name is set") + return account, user, nil + } + if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { + if groups, ok := claim.([]string); ok { + if err := account.AddJWTGroups(groups); err != nil { + log.Errorf("failed to add JWT groups: %v", err) + } + } else { + log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) + } + } else { + log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) + } + } + return account, user, nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index 495e93892..4d2d6d5c4 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/route" @@ -460,6 +461,62 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { } } +func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { + userId := "user-id" + domain := "test.domain" + + initAccount := newAccountWithId("", userId, domain) + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID := initAccount.Id + _, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain) + require.NoError(t, err, "create init user failed") + + claims := jwtclaims.AuthorizationClaims{ + AccountId: accountID, + Domain: domain, + UserId: userId, + DomainCategory: "test-category", + Raw: jwt.MapClaims{"idp-groups": []string{"group1", "group2"}}, + } + + t.Run("JWT groups disabled", func(t *testing.T) { + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") + }) + + t.Run("JWT groups enabled without claim name", func(t *testing.T) { + initAccount.Settings.JWTGroupsEnabled = true + manager.Store.SaveAccount(initAccount) + + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") + }) + + t.Run("JWT groups enabled", func(t *testing.T) { + initAccount.Settings.JWTGroupsEnabled = true + initAccount.Settings.JWTGroupsClaimName = "idp-groups" + manager.Store.SaveAccount(initAccount) + + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") + + g1, ok := account.Groups["group1"] + require.True(t, ok, "group1 should be added to the account") + require.Equal(t, g1.Name, "group1", "group1 name should match") + require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + + g2, ok := account.Groups["group2"] + require.True(t, ok, "group2 should be added to the account") + require.Equal(t, g2.Name, "group2", "group2 name should match") + require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + }) +} + func TestAccountManager_GetAccountFromPAT(t *testing.T) { store := newStore(t) account := newAccountWithId("account_id", "testuser", "") diff --git a/management/server/file_store.go b/management/server/file_store.go index c39af154a..09a6063d7 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -205,6 +205,10 @@ func restore(file string) (*FileStore, error) { group.Peers[i] = p.ID } } + // TODO: delete this block after migration + if group.Issued == "" { + group.Issued = GroupIssuedAPI + } } // detect routes that have Peer.Key as a reference and replace it with ID. diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index ce035ea83..7d1b49812 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -262,6 +262,7 @@ func TestRestore(t *testing.T) { require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") } +// TODO: outdated, delete this func TestRestorePolicies_Migration(t *testing.T) { storeDir := t.TempDir() @@ -296,6 +297,30 @@ func TestRestorePolicies_Migration(t *testing.T) { "failed to restore a FileStore file - missing Account Policies Sources") } +func TestRestoreGroups_Migration(t *testing.T) { + storeDir := t.TempDir() + + err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir, nil) + if err != nil { + return + } + + account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] + require.Len(t, account.Groups, 1, "failed to restore a FileStore file - missing Account Groups") + + var group *Group + for _, g := range account.Groups { + group = g + } + + require.Equal(t, group.Issued, GroupIssuedAPI, "default group should has API issued mark") +} + func TestGetAccountByPrivateDomain(t *testing.T) { storeDir := t.TempDir() diff --git a/management/server/group.go b/management/server/group.go index 688d6dafa..dd1229c86 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -14,6 +14,9 @@ type Group struct { // Name visible in the UI Name string + // Issued of the group + Issued string + // Peers list of the group Peers []string } @@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any { func (g *Group) Copy() *Group { return &Group{ - ID: g.ID, - Name: g.Name, - Peers: g.Peers[:], + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: g.Peers[:], } } diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 2d7dc499a..946c0b8be 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -1,9 +1,15 @@ package jwtclaims +import ( + "github.com/golang-jwt/jwt" +) + // AuthorizationClaims stores authorization information from JWTs type AuthorizationClaims struct { UserId string AccountId string Domain string DomainCategory string + + Raw jwt.MapClaims } diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9aa00a004..466856d77 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { // FromToken extracts claims from the token (after auth) func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { claims := token.Claims.(jwt.MapClaims) - jwtClaims := AuthorizationClaims{} + jwtClaims := AuthorizationClaims{ + Raw: claims, + } userID, ok := claims[c.userIDClaim].(string) if !ok { return jwtClaims diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index 53f8818b1..9bececac6 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { Domain: "test.com", AccountId: "testAcc", DomainCategory: "public", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "https://login/wt_account_domain_category": "public", + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", AccountId: "testAcc", + Raw: jwt.MapClaims{ + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", Domain: "test.com", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { UserId: "test", Domain: "test.com", AccountId: "testAcc", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAudiance: "https://login/", inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", + Raw: jwt.MapClaims{ + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims",