mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Jwtclaims package (#242)
* Move JWTClaims logic to its own package * Add extractor tests
This commit is contained in:
8
management/server/jwtclaims/claims.go
Normal file
8
management/server/jwtclaims/claims.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package jwtclaims
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
}
|
||||
51
management/server/jwtclaims/extractor.go
Normal file
51
management/server/jwtclaims/extractor.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenUserProperty = "user"
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
DomainIDSuffix = "wt_account_domain"
|
||||
UserIDClaim = "sub"
|
||||
)
|
||||
|
||||
// Extract function type
|
||||
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
|
||||
|
||||
// ClaimsExtractor struct that holds the extract function
|
||||
type ClaimsExtractor struct {
|
||||
ExtractClaimsFromRequestContext ExtractClaims
|
||||
}
|
||||
|
||||
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
||||
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
||||
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
|
||||
var extractFunc ExtractClaims
|
||||
if extractFunc = e; extractFunc == nil {
|
||||
extractFunc = ExtractClaimsFromRequestContext
|
||||
}
|
||||
|
||||
return &ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: extractFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func ExtractClaimsFromRequestContext(r *http.Request, authAudiance string) AuthorizationClaims {
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims.UserId = claims[UserIDClaim].(string)
|
||||
accountIdClaim, ok := claims[authAudiance+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIdClaim.(string)
|
||||
}
|
||||
domainClaim, ok := claims[authAudiance+DomainIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.Domain = domainClaim.(string)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
94
management/server/jwtclaims/extractor_test.go
Normal file
94
management/server/jwtclaims/extractor_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
||||
claimMaps := jwt.MapClaims{}
|
||||
if claims.UserId != "" {
|
||||
claimMaps[UserIDClaim] = claims.UserId
|
||||
}
|
||||
if claims.AccountId != "" {
|
||||
claimMaps[audiance+AccountIDSuffix] = claims.AccountId
|
||||
}
|
||||
if claims.Domain != "" {
|
||||
claimMaps[audiance+DomainIDSuffix] = claims.Domain
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
|
||||
|
||||
return testRequest
|
||||
}
|
||||
|
||||
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
inputAuthorizationClaims AuthorizationClaims
|
||||
inputAudiance string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedMSG string
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "All Claim Fields",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Domain Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
AccountId: "testAcc",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Account ID Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Only User ID Is set",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||
|
||||
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
|
||||
|
||||
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user