Feature: add custom id claim (#667)

This feature allows using the custom claim in the JWT token as a user ID.

Refactor claims extractor with options support

Add is_current to the user API response
This commit is contained in:
Givi Khojanashvili
2023-02-04 00:47:20 +04:00
committed by GitHub
parent 494e56d1be
commit 3ec8274b8e
32 changed files with 474 additions and 305 deletions

View File

@@ -1,8 +1,9 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
"net/http"
"github.com/golang-jwt/jwt"
)
const (
@@ -14,51 +15,85 @@ const (
)
// Extract function type
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
ExtractClaimsFromRequestContext ExtractClaims
authAudience string
userIDClaim string
FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
type ClaimsExtractorOption func(*ClaimsExtractor)
// WithAudience sets the audience for the extractor
func WithAudience(audience string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.authAudience = audience
}
}
// WithUserIDClaim sets the user id claim for the extractor
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.userIDClaim = userIDClaim
}
}
// WithFromRequestContext sets the function that extracts claims from the request context
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.FromRequestContext = ec
}
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
var extractFunc ExtractClaims
if extractFunc = e; extractFunc == nil {
extractFunc = ExtractClaimsFromRequestContext
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
ce := &ClaimsExtractor{}
for _, option := range options {
option(ce)
}
return &ClaimsExtractor{
ExtractClaimsFromRequestContext: extractFunc,
if ce.FromRequestContext == nil {
ce.FromRequestContext = ce.fromRequestContext
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return ExtractClaimsWithToken(token, authAudience)
}
// ExtractClaimsWithToken extracts claims from the token (after auth)
func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims {
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims.UserId = claims[UserIDClaim].(string)
accountIdClaim, ok := claims[authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIdClaim.(string)
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims
}
domainClaim, ok := claims[authAudience+DomainIDSuffix]
jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIDClaim.(string)
}
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix]
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string)
}
return jwtClaims
}
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return c.FromToken(token)
}

View File

@@ -2,10 +2,11 @@ package jwtclaims
import (
"context"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
)
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
@@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
}
func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct {
name string
inputAuthorizationClaims AuthorizationClaims
@@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
extractedClaims := extractor.FromRequestContext(request)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
})
}
}
func TestExtractClaimsSetOptions(t *testing.T) {
type test struct {
name string
extractor *ClaimsExtractor
check func(t *testing.T, c test)
}
testCase1 := test{
name: "No custom options",
extractor: NewClaimsExtractor(),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "" {
t.Error("audience should be empty")
return
}
if c.extractor.userIDClaim != UserIDClaim {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return
}
if c.extractor.FromRequestContext == nil {
t.Error("from request context should not be nil")
return
}
},
}
testCase2 := test{
name: "Custom audience",
extractor: NewClaimsExtractor(WithAudience("https://login/")),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "https://login/" {
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
return
}
},
}
testCase3 := test{
name: "Custom user id claim",
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
check: func(t *testing.T, c test) {
if c.extractor.userIDClaim != "customUserId" {
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
return
}
},
}
testCase4 := test{
name: "Custom extractor from request context",
extractor: NewClaimsExtractor(
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
return AuthorizationClaims{
UserId: "testCustomRequest",
}
})),
check: func(t *testing.T, c test) {
claims := c.extractor.FromRequestContext(&http.Request{})
if claims.UserId != "testCustomRequest" {
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
return
}
},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.check(t, testCase)
})
}
}