[management] refactor auth (#3296)

This commit is contained in:
Pedro Maia Costa
2025-02-20 20:24:40 +00:00
committed by GitHub
parent d7d5b1b1d6
commit 77e40f41f2
64 changed files with 2085 additions and 1937 deletions

View File

@@ -7,30 +7,24 @@ import (
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
claimsExtract jwtclaims.ClaimsExtractor
getUser GetUser
getUser GetUser
}
// NewAccessControl instance constructor
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
func NewAccessControl(getUser GetUser) *AccessControl {
return &AccessControl{
claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
getUser: getUser,
}
}
@@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return
}
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
}
user, err := a.getUser(r.Context(), userAuth)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
return
}

View File

@@ -8,67 +8,41 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string
userIDClaim string
authManager auth.Manager
ensureAccount EnsureAccountFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}
func NewAuthMiddleware(
authManager auth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
) *AuthMiddleware {
return &AuthMiddleware{
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
claimsExtractor: claimsExtractor,
audience: audience,
userIDClaim: userIdClaim,
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
@@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
request, err := m.checkJWTFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, request)
case "token":
err := m.checkPATFromRequest(w, r, auth)
request, err := m.checkPATFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, request)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil {
return err
return r, err
}
if validatedToken == nil {
return nil
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
}
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return r, err
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return r, err
}
err = m.syncUserJWTGroups(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
return r, fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.GetExpirationDate()) {
return fmt.Errorf("token expired")
return r, fmt.Errorf("token expired")
}
err = m.markPATUsed(r.Context(), pat.ID)
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
return err
return r, err
}
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
userAuth := nbcontext.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,
DomainCategory: accCategory,
IsPAT: true,
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
@@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
// the PAT token from the Authorization header.
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
return "", errors.New("authorization header format must be Token {token}")
}
return authHeaderParts[1], nil

View File

@@ -9,10 +9,14 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
return nbcontext.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
},
Valid: true,
}, nil
&jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + nbjwt.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
}
return nil, fmt.Errorf("JWT invalid")
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
if _, ok := testAccount.Users[claims.UserId]; !ok {
return fmt.Errorf("user with id %s does not exist", claims.UserId)
if testAccount.Id != userAuth.AccountId {
return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId)
}
return nil
if _, ok := testAccount.Users[userAuth.UserId]; !ok {
return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId)
}
return userAuth, nil
}
func TestAuthMiddleware_Handler(t *testing.T) {
@@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do nothing
})
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
authMiddleware := NewAuthMiddleware(
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,
claimsExtractor,
audience,
userIDClaim,
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) {
result := rec.Result()
defer result.Body.Close()
if result.StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode)
}
})
}
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
path string
authHeader string
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsPAT: true,
},
},
{
name: "Valid PAT Token ignores child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsPAT: true,
},
},
{
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
},
},
{
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
},
},
}
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if tc.expectedUserAuth != nil {
assert.NoError(t, err)
assert.Equal(t, *tc.expectedUserAuth, userAuth)
} else {
assert.Error(t, err)
assert.Empty(t, userAuth)
}
}))
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder()
handlerToTest.ServeHTTP(rec, req)
result := rec.Result()
defer result.Body.Close()
if tc.expectedUserAuth != nil {
assert.Equal(t, 200, result.StatusCode)
} else {
assert.Equal(t, 401, result.StatusCode)
}
})
}
}