Enable JWT group-based user authorization (#1368)

* Extend management API to support list of allowed JWT groups (#1366)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Add JWT group-based user authorization (#1373)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Implement user access validation authentication based on JWT groups

* Remove the slices package import due to compatibility issues with the gitHub workflow(s) Go version

* Refactor auth middleware and test for extracted claim handling

* Optimize JWT group check in auth middleware to cover nil and empty allowed groups
This commit is contained in:
Bethuel Mmbaga
2023-12-11 18:59:15 +03:00
committed by GitHub
parent 5ecafef5d2
commit d275d411aa
8 changed files with 133 additions and 10 deletions

View File

@@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// GetAccountFromTokenFunc function
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
getAccountFromToken GetAccountFromTokenFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string
userIDClaim string
}
@@ -40,14 +45,19 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware {
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
getAccountFromToken: getAccountFromToken,
claimsExtractor: claimsExtractor,
audience: audience,
userIDClaim: userIdClaim,
}
@@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(validatedToken); err != nil {
return err
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
@@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
account, _, err := m.getAccountFromToken(authClaims)
if err != nil {
return fmt.Errorf("failed to get the account from token: %w", err)
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
@@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
return authHeaderParts[1], nil
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
@@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
@@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error {
return fmt.Errorf("Should never get reached")
}
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
if testAccount.Id != claims.AccountId {
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
}
user, ok := testAccount.Users[claims.UserId]
if !ok {
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
}
return testAccount, user, nil
}
func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
@@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) {
// do nothing
})
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim)
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockGetAccountFromToken,
claimsExtractor,
audience,
userIDClaim,
)
handlerToTest := authMiddleware.Handler(nextHandler)