Add context to throughout the project and update logging (#2209)

propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
pascal-fischer
2024-07-03 11:33:02 +02:00
committed by GitHub
parent 7cb81f1d70
commit 765aba2c1c
127 changed files with 2936 additions and 2642 deletions

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"net/http"
"regexp"
@@ -15,7 +16,7 @@ import (
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
@@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(claims)
user, err := a.getUser(r.Context(), claims)
if err != nil {
log.Errorf("failed to get user from claims: %s", err)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
@@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
if tokenPathRegexp.MatchString(r.URL.Path) {
log.Debugf("valid Path")
log.WithContext(r.Context()).Debugf("valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
return
}
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -19,16 +20,16 @@ import (
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
if err != nil {
log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.checkPATFromRequest(w, r, auth)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
validatedToken, err := m.validateAndParseToken(r.Context(), token)
if err != nil {
return err
}
@@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(validatedToken); err != nil {
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err
}
@@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(authClaims)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
}
// CheckPATFromRequest checks if the PAT is valid
@@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
err = m.markPATUsed(r.Context(), pat.ID)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@@ -15,15 +16,16 @@ import (
)
const (
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
@@ -47,14 +49,14 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
@@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
}

View File

@@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath)
if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {