diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c99b9b51f..6e9b029c7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -36,7 +36,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid accountManager.GetAccountFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, - authCfg.Audience) + authCfg.Audience, + authCfg.UserIDClaim) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 898ad0875..710723124 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -32,6 +32,7 @@ type AuthMiddleware struct { validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc audience string + userIDClaim string } const ( @@ -39,12 +40,16 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware { +func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware { + if userIdClaim == "" { + userIdClaim = jwtclaims.UserIDClaim + } return &AuthMiddleware{ getAccountFromPAT: getAccountFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, audience: audience, + userIDClaim: userIdClaim, } } @@ -127,7 +132,7 @@ func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Requ } claimMaps := jwt.MapClaims{} - claimMaps[jwtclaims.UserIDClaim] = user.Id + claimMaps[m.userIDClaim] = user.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index b041b12d5..8c8c941b0 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -13,14 +13,15 @@ import ( ) const ( - audience = "audience" - accountID = "accountID" - domain = "domain" - userID = "userID" - tokenID = "tokenID" - PAT = "PAT" - JWT = "JWT" - wrongToken = "wrongToken" + audience = "audience" + userIDClaim = "userIDClaim" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" ) var testAccount = &server.Account{ @@ -102,7 +103,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { // do nothing }) - authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim) handlerToTest := authMiddleware.Handler(nextHandler)