mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
[management] refactor auth (#3296)
This commit is contained in:
@@ -7,30 +7,24 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
claimsExtract jwtclaims.ClaimsExtractor
|
||||
getUser GetUser
|
||||
getUser GetUser
|
||||
}
|
||||
|
||||
// NewAccessControl instance constructor
|
||||
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||
func NewAccessControl(getUser GetUser) *AccessControl {
|
||||
return &AccessControl{
|
||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
),
|
||||
getUser: getUser,
|
||||
}
|
||||
}
|
||||
@@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
|
||||
}
|
||||
|
||||
user, err := a.getUser(r.Context(), userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,67 +8,41 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(ctx context.Context, token string) error
|
||||
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
audience string
|
||||
userIDClaim string
|
||||
authManager auth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
}
|
||||
|
||||
const (
|
||||
userProperty = "user"
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
userIdClaim = jwtclaims.UserIDClaim
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
claimsExtractor: claimsExtractor,
|
||||
audience: audience,
|
||||
userIDClaim: userIdClaim,
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
}
|
||||
@@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
claims := m.claimsExtractor.FromRequestContext(r)
|
||||
//nolint
|
||||
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
if validatedToken == nil {
|
||||
return nil
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = ok
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
|
||||
return err
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
// If we get here, everything worked and we can set the
|
||||
// user property in context.
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
}
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
return m.checkUserAccessByJWTGroups(ctx, authClaims)
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return fmt.Errorf("token expired")
|
||||
return r, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(r.Context(), pat.ID)
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
claimMaps[jwtclaims.IsToken] = true
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
userAuth := nbcontext.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
DomainCategory: accCategory,
|
||||
IsPAT: true,
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("Authorization header format must be Bearer {token}")
|
||||
return "", errors.New("authorization header format must be Bearer {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
@@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
// the PAT token from the Authorization header.
|
||||
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
|
||||
return "", errors.New("Authorization header format must be Token {token}")
|
||||
return "", errors.New("authorization header format must be Token {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
|
||||
@@ -9,10 +9,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
userIDClaim: userID,
|
||||
audience + jwtclaims.AccountIDSuffix: accountID,
|
||||
return nbcontext.UserAuth{
|
||||
UserId: userID,
|
||||
AccountId: accountID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
},
|
||||
Valid: true,
|
||||
}, nil
|
||||
&jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
userIDClaim: userID,
|
||||
audience + nbjwt.AccountIDSuffix: accountID,
|
||||
},
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
@@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
if _, ok := testAccount.Users[claims.UserId]; !ok {
|
||||
return fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||
if testAccount.Id != userAuth.AccountId {
|
||||
return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId)
|
||||
}
|
||||
|
||||
return nil
|
||||
if _, ok := testAccount.Users[userAuth.UserId]; !ok {
|
||||
return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId)
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
@@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
|
||||
})
|
||||
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
)
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
claimsExtractor,
|
||||
audience,
|
||||
userIDClaim,
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if result.StatusCode != tc.expectedStatusCode {
|
||||
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
path: "/test",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsPAT: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid PAT Token ignores child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsPAT: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid JWT Token",
|
||||
path: "/test",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "Valid JWT Token with child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsChild: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if tc.expectedUserAuth != nil {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *tc.expectedUserAuth, userAuth)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, userAuth)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
|
||||
req.Header.Set("Authorization", tc.authHeader)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if tc.expectedUserAuth != nil {
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, 401, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user