mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Add context to throughout the project and update logging (#2209)
propagate context from all the API calls and log request ID, account ID and peer ID --------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(claims)
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if user.IsBlocked() {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
if tokenPathRegexp.MatchString(r.URL.Path) {
|
||||
log.Debugf("valid Path")
|
||||
log.WithContext(r.Context()).Debugf("valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -19,16 +20,16 @@ import (
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
type MarkPATUsedFunc func(ctx context.Context, token string) error
|
||||
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
@@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
case "bearer":
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
claims := m.claimsExtractor.FromRequestContext(r)
|
||||
//nolint
|
||||
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(token)
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(validatedToken); err != nil {
|
||||
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
return m.checkUserAccessByJWTGroups(authClaims)
|
||||
return m.checkUserAccessByJWTGroups(ctx, authClaims)
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
@@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(token)
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(pat.ID)
|
||||
err = m.markPATUsed(r.Context(), pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,15 +16,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
@@ -47,14 +49,14 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
@@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(token string) error {
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
|
||||
for bypassPath := range bypassPaths {
|
||||
matched, err := path.Match(bypassPath, requestPath)
|
||||
if err != nil {
|
||||
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
continue
|
||||
}
|
||||
if matched {
|
||||
|
||||
Reference in New Issue
Block a user