mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Feature: add custom id claim (#667)
This feature allows using the custom claim in the JWT token as a user ID. Refactor claims extractor with options support Add is_current to the user API response
This commit is contained in:
committed by
GitHub
parent
494e56d1be
commit
3ec8274b8e
@@ -1,8 +1,9 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -14,51 +15,85 @@ const (
|
||||
)
|
||||
|
||||
// Extract function type
|
||||
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
|
||||
type ExtractClaims func(r *http.Request) AuthorizationClaims
|
||||
|
||||
// ClaimsExtractor struct that holds the extract function
|
||||
type ClaimsExtractor struct {
|
||||
ExtractClaimsFromRequestContext ExtractClaims
|
||||
authAudience string
|
||||
userIDClaim string
|
||||
|
||||
FromRequestContext ExtractClaims
|
||||
}
|
||||
|
||||
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
|
||||
type ClaimsExtractorOption func(*ClaimsExtractor)
|
||||
|
||||
// WithAudience sets the audience for the extractor
|
||||
func WithAudience(audience string) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.authAudience = audience
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserIDClaim sets the user id claim for the extractor
|
||||
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.userIDClaim = userIDClaim
|
||||
}
|
||||
}
|
||||
|
||||
// WithFromRequestContext sets the function that extracts claims from the request context
|
||||
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.FromRequestContext = ec
|
||||
}
|
||||
}
|
||||
|
||||
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
||||
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
||||
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
|
||||
var extractFunc ExtractClaims
|
||||
if extractFunc = e; extractFunc == nil {
|
||||
extractFunc = ExtractClaimsFromRequestContext
|
||||
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
ce := &ClaimsExtractor{}
|
||||
for _, option := range options {
|
||||
option(ce)
|
||||
}
|
||||
|
||||
return &ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: extractFunc,
|
||||
if ce.FromRequestContext == nil {
|
||||
ce.FromRequestContext = ce.fromRequestContext
|
||||
}
|
||||
if ce.userIDClaim == "" {
|
||||
ce.userIDClaim = UserIDClaim
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims {
|
||||
if r.Context().Value(TokenUserProperty) == nil {
|
||||
return AuthorizationClaims{}
|
||||
}
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
return ExtractClaimsWithToken(token, authAudience)
|
||||
}
|
||||
|
||||
// ExtractClaimsWithToken extracts claims from the token (after auth)
|
||||
func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims {
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims.UserId = claims[UserIDClaim].(string)
|
||||
accountIdClaim, ok := claims[authAudience+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIdClaim.(string)
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
}
|
||||
domainClaim, ok := claims[authAudience+DomainIDSuffix]
|
||||
jwtClaims.UserId = userID
|
||||
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIDClaim.(string)
|
||||
}
|
||||
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.Domain = domainClaim.(string)
|
||||
}
|
||||
domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix]
|
||||
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
||||
if ok {
|
||||
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
|
||||
if r.Context().Value(TokenUserProperty) == nil {
|
||||
return AuthorizationClaims{}
|
||||
}
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
return c.FromToken(token)
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package jwtclaims
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
||||
@@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
|
||||
}
|
||||
|
||||
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
inputAuthorizationClaims AuthorizationClaims
|
||||
@@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||
|
||||
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
|
||||
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
|
||||
extractedClaims := extractor.FromRequestContext(request)
|
||||
|
||||
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsSetOptions(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
extractor *ClaimsExtractor
|
||||
check func(t *testing.T, c test)
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "No custom options",
|
||||
extractor: NewClaimsExtractor(),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.authAudience != "" {
|
||||
t.Error("audience should be empty")
|
||||
return
|
||||
}
|
||||
if c.extractor.userIDClaim != UserIDClaim {
|
||||
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
if c.extractor.FromRequestContext == nil {
|
||||
t.Error("from request context should not be nil")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Custom audience",
|
||||
extractor: NewClaimsExtractor(WithAudience("https://login/")),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.authAudience != "https://login/" {
|
||||
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Custom user id claim",
|
||||
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.userIDClaim != "customUserId" {
|
||||
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Custom extractor from request context",
|
||||
extractor: NewClaimsExtractor(
|
||||
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
|
||||
return AuthorizationClaims{
|
||||
UserId: "testCustomRequest",
|
||||
}
|
||||
})),
|
||||
check: func(t *testing.T, c test) {
|
||||
claims := c.extractor.FromRequestContext(&http.Request{})
|
||||
if claims.UserId != "testCustomRequest" {
|
||||
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.check(t, testCase)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user