Compare commits

...

4 Commits

Author SHA1 Message Date
bcmmbaga
6119ce791d Refactor 2026-04-21 20:51:42 +03:00
bcmmbaga
60b375455b Refactor 2026-04-21 20:48:29 +03:00
bcmmbaga
76d631d506 propagate the ctx changes to upstream 2026-04-21 20:41:27 +03:00
bcmmbaga
2c1103883c log transcation caller 2026-04-21 16:59:52 +03:00
3 changed files with 28 additions and 33 deletions

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth" serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType { switch authType {
case "bearer": case "bearer":
request, err := m.checkJWTFromRequest(r, authHeader) if err := m.checkJWTFromRequest(r, authHeader); err != nil {
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return return
} }
h.ServeHTTP(w, r)
h.ServeHTTP(w, request)
case "token": case "token":
request, err := m.checkPATFromRequest(r, authHeader) if err := m.checkPATFromRequest(r, authHeader); err != nil {
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized // Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok { if _, ok := status.FromError(err); !ok {
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
h.ServeHTTP(w, request) h.ServeHTTP(w, r)
default: default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return return
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
} }
// CheckJWTFromRequest checks if the JWT is valid // CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts) token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error // If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return r, fmt.Errorf("error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
ctx := r.Context() ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil { if err != nil {
return r, err return err
} }
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth) accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil { if err != nil {
return r, err return err
} }
if userAuth.AccountId != accountId { if userAuth.AccountId != accountId {
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil { if err != nil {
return r, err return err
} }
err = m.syncUserJWTGroups(ctx, userAuth) err = m.syncUserJWTGroups(ctx, userAuth)
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
_, err = m.getUserFromUserAuth(ctx, userAuth) _, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return r, err return err
} }
return nbcontext.SetUserAuthInRequest(r, userAuth), nil // propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
} }
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromPATRequest(authHeaderParts) token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil { if err != nil {
return r, fmt.Errorf("error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
if m.patUsageTracker != nil { if m.patUsageTracker != nil {
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
if m.rateLimiter != nil && !isTerraformRequest(r) { if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) { if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests") return status.Errorf(status.TooManyRequests, "too many requests")
} }
} }
ctx := r.Context() ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil { if err != nil {
return r, fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
if time.Now().After(pat.GetExpirationDate()) { if time.Now().After(pat.GetExpirationDate()) {
return r, fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
err = m.authManager.MarkPATUsed(ctx, pat.ID) err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil { if err != nil {
return r, err return err
} }
userAuth := auth.UserAuth{ userAuth := auth.UserAuth{
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
} }
} }
return nbcontext.SetUserAuthInRequest(r, userAuth), nil // propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
} }
func isTerraformRequest(r *http.Request) bool { func isTerraformRequest(r *http.Request) bool {

View File

@@ -45,6 +45,7 @@ import (
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
nbutil "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/util/crypt"
) )

View File

@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
} }
}) })
h.ServeHTTP(w, r.WithContext(ctx)) // Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
req := r.WithContext(ctx)
h.ServeHTTP(w, req)
close(handlerDone) close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) ctx = req.Context()
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 { if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())