Jwtclaims package (#242)

* Move JWTClaims logic to its own package

* Add extractor tests
This commit is contained in:
Maycon Santos
2022-02-23 20:02:02 +01:00
committed by GitHub
parent 5f5cbf7e20
commit b29948b910
7 changed files with 164 additions and 53 deletions

View File

@@ -0,0 +1,8 @@
package jwtclaims
// AuthorizationClaims stores authorization information from JWTs
type AuthorizationClaims struct {
UserId string
AccountId string
Domain string
}

View File

@@ -0,0 +1,51 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
"net/http"
)
const (
TokenUserProperty = "user"
AccountIDSuffix = "wt_account_id"
DomainIDSuffix = "wt_account_domain"
UserIDClaim = "sub"
)
// Extract function type
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
ExtractClaimsFromRequestContext ExtractClaims
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
var extractFunc ExtractClaims
if extractFunc = e; extractFunc == nil {
extractFunc = ExtractClaimsFromRequestContext
}
return &ClaimsExtractor{
ExtractClaimsFromRequestContext: extractFunc,
}
}
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) AuthorizationClaims {
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims.UserId = claims[UserIDClaim].(string)
accountIdClaim, ok := claims[authAudiance+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIdClaim.(string)
}
domainClaim, ok := claims[authAudiance+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
return jwtClaims
}

View File

@@ -0,0 +1,94 @@
package jwtclaims
import (
"context"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"net/http"
"testing"
)
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
claimMaps := jwt.MapClaims{}
if claims.UserId != "" {
claimMaps[UserIDClaim] = claims.UserId
}
if claims.AccountId != "" {
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
}
if claims.Domain != "" {
claimMaps[audiance+DomainIDSuffix] = claims.Domain
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
return testRequest
}
func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct {
name string
inputAuthorizationClaims AuthorizationClaims
inputAudiance string
testingFunc require.ComparisonAssertionFunc
expectedMSG string
}
testCase1 := test{
name: "All Claim Fields",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase2 := test{
name: "Domain Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
AccountId: "testAcc",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase3 := test{
name: "Account ID Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase4 := test{
name: "Only User ID Is set",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
})
}
}