diff --git a/management/server/account.go b/management/server/account.go index 4c13c8535..85fdf45c5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -164,6 +164,9 @@ type Settings struct { // JWTGroupsClaimName from which we extract groups name to add it to account groups JWTGroupsClaimName string + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + // Extra is a dictionary of Account settings Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } @@ -176,6 +179,7 @@ func (s *Settings) Copy() *Settings { JWTGroupsEnabled: s.JWTGroupsEnabled, JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index c2751abd4..bab00219b 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) if req.Settings.JwtGroupsClaimName != nil { settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName } + if req.Settings.JwtAllowGroups != nil { + settings.JWTAllowGroups = *req.Settings.JwtAllowGroups + } updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) if err != nil { @@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) } func toAccountResponse(account *server.Account) *api.Account { + jwtAllowGroups := account.Settings.JWTAllowGroups + if jwtAllowGroups == nil { + jwtAllowGroups = []string{} + } + settings := api.AccountSettings{ PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + JwtAllowGroups: &jwtAllowGroups, } if account.Settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 08c98c830..fd2c4bfcd 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, }, expectedArray: true, expectedID: accountID, @@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, }, expectedArray: false, expectedID: accountID, @@ -121,7 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -129,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(false), JwtGroupsClaimName: sr("roles"), JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, }, expectedArray: false, expectedID: accountID, @@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { GroupsPropagationEnabled: br(true), JwtGroupsClaimName: sr("groups"), JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{}, }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 3a125bdd7..1a049a0cf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -66,6 +66,12 @@ components: description: Name of the claim from which we extract groups names to add it to account groups. type: string example: "roles" + jwt_allow_groups: + description: List of groups to which users are allowed access + type: array + items: + type: string + example: Administrators extra: $ref: '#/components/schemas/AccountExtraSettings' required: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 820cf5c48..329c66884 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -160,6 +160,9 @@ type AccountSettings struct { // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` + // JwtAllowGroups List of groups to which users are allowed access + JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"` + // JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups. JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"` diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 8c77d27dc..fba7cc311 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -34,12 +34,20 @@ type emptyObject struct { // APIHandler creates the Management service HTTP API handler registering all the available endpoints. func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { + claimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ) + authMiddleware := middleware.NewAuthMiddleware( accountManager.GetAccountFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, + accountManager.GetAccountFromToken, + claimsExtractor, authCfg.Audience, - authCfg.UserIDClaim) + authCfg.UserIDClaim, + ) corsMiddleware := cors.AllowAll() @@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid AuthCfg: authCfg, } - claimsExtractor := jwtclaims.NewClaimsExtractor( - jwtclaims.WithAudience(authCfg.Audience), - jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), - ) - integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) api.addAccountsEndpoint() api.addPeersEndpoint() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 99482bfb7..10bad2fdb 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) // MarkPATUsedFunc function type MarkPATUsedFunc func(token string) error +// GetAccountFromTokenFunc function +type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { getAccountFromPAT GetAccountFromPATFunc validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc + getAccountFromToken GetAccountFromTokenFunc + claimsExtractor *jwtclaims.ClaimsExtractor audience string userIDClaim string } @@ -40,14 +45,19 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware { +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, + markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor, + audience string, userIdClaim string) *AuthMiddleware { if userIdClaim == "" { userIdClaim = jwtclaims.UserIDClaim } + return &AuthMiddleware{ getAccountFromPAT: getAccountFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, + getAccountFromToken: getAccountFromToken, + claimsExtractor: claimsExtractor, audience: audience, userIDClaim: userIdClaim, } @@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } + if err := m.verifyUserAccess(validatedToken); err != nil { + return err + } + // If we get here, everything worked and we can set the // user property in context. newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint @@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } +// verifyUserAccess checks if a user, based on a validated JWT token, +// is allowed access, particularly in cases where the admin enabled JWT +// group propagation and designated certain groups with access permissions. +func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error { + authClaims := m.claimsExtractor.FromToken(validatedToken) + account, _, err := m.getAccountFromToken(authClaims) + if err != nil { + return fmt.Errorf("failed to get the account from token: %w", err) + } + + // Ensures JWT group synchronization to the management is enabled before, + // filtering access based on the allowed groups. + if account.Settings != nil && account.Settings.JWTGroupsEnabled { + if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := make([]string, 0) + + if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } + } + } + } + + if !userHasAllowedGroup(allowedGroups, userJWTGroups) { + return fmt.Errorf("user does not belong to any of the allowed JWT groups") + } + } + } + + return nil +} + // CheckPATFromRequest checks if the PAT is valid func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { token, err := getTokenFromPATRequest(auth) @@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) { return authHeaderParts[1], nil } + +// userHasAllowedGroup checks if a user belongs to any of the allowed groups. +func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { + for _, userGroup := range userGroups { + for _, allowedGroup := range allowedGroups { + if userGroup == allowedGroup { + return true + } + } + } + return false +} diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 55e5de260..531f4d886 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,6 +10,7 @@ import ( "github.com/golang-jwt/jwt" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" ) const ( @@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server func mockValidateAndParseToken(token string) (*jwt.Token, error) { if token == JWT { - return &jwt.Token{}, nil + return &jwt.Token{ + Claims: jwt.MapClaims{ + userIDClaim: userID, + audience + jwtclaims.AccountIDSuffix: accountID, + }, + Valid: true, + }, nil } return nil, fmt.Errorf("JWT invalid") } @@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error { return fmt.Errorf("Should never get reached") } +func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + if testAccount.Id != claims.AccountId { + return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId) + } + + user, ok := testAccount.Users[claims.UserId] + if !ok { + return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId) + } + + return testAccount, user, nil +} + func TestAuthMiddleware_Handler(t *testing.T) { tt := []struct { name string @@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) { // do nothing }) - authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim) + claimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ) + + authMiddleware := NewAuthMiddleware( + mockGetAccountFromPAT, + mockValidateAndParseToken, + mockMarkPATUsed, + mockGetAccountFromToken, + claimsExtractor, + audience, + userIDClaim, + ) handlerToTest := authMiddleware.Handler(nextHandler)