From 7d08a609e66abca2be22d539239202fedfab0ad9 Mon Sep 17 00:00:00 2001 From: mlsmaycon Date: Tue, 10 Feb 2026 20:55:07 +0100 Subject: [PATCH] fix: capture account/service/user IDs in access logs for auth requests - Add accountID and serviceID to auth middleware DomainConfig - Set account/service IDs in CapturedData when domain is matched - Update AddDomain to accept accountID and serviceID parameters - Skip access logging for internal proxy assets (/__netbird__/*) - Return validationResult struct from validateSessionToken to preserve user ID even when access is denied - Capture user ID and auth method in access logs for denied requests --- proxy/internal/accesslog/middleware.go | 8 +++++ proxy/internal/auth/middleware.go | 50 ++++++++++++++++++++------ proxy/internal/auth/middleware_test.go | 44 +++++++++++------------ proxy/server.go | 2 +- 4 files changed, 70 insertions(+), 34 deletions(-) diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index c48e853bb..ca7556bfd 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -3,15 +3,23 @@ package accesslog import ( "net" "net/http" + "strings" "time" "github.com/rs/xid" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/web" ) func (l *Logger) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip logging for internal proxy assets (CSS, JS, etc.) + if strings.HasPrefix(r.URL.Path, web.PathPrefix+"/") { + next.ServeHTTP(w, r) + return + } + // Generate request ID early so it can be used by error pages and log correlation. requestID := xid.New().String() diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 9621a58c7..e4f937019 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/proxy/web" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -44,6 +45,14 @@ type DomainConfig struct { Schemes []Scheme SessionPublicKey ed25519.PublicKey SessionExpiration time.Duration + AccountID string + ServiceID string +} + +type validationResult struct { + UserID string + Valid bool + DeniedReason string } type Middleware struct { @@ -94,6 +103,12 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + // Set account and service IDs in captured data for access logging. + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetAccountId(types.AccountID(config.AccountID)) + cd.SetServiceId(config.ServiceID) + } + // Check for error from OAuth callback (e.g., access denied) if errCode := r.URL.Query().Get("error"); errCode != "" { var requestID string @@ -126,7 +141,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { for _, scheme := range config.Schemes { token, promptData := scheme.Authenticate(r) if token != "" { - userID, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type()) + result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type()) if err != nil { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) @@ -134,10 +149,12 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { http.Error(w, err.Error(), http.StatusBadRequest) return } - if userID == "" { + if !result.Valid { var requestID string if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) requestID = cd.GetRequestID() } web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID) @@ -161,6 +178,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { // The browser will follow with a GET carrying the new session cookie. if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) } redirectURL := stripSessionTokenParam(r.URL) http.Redirect(w, r, redirectURL, http.StatusSeeOther) @@ -181,11 +200,14 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { // session JWTs. Returns an error if the key is missing or invalid. // Callers must not serve the domain if this returns an error, to avoid // exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) error { +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() - mw.domains[domain] = DomainConfig{} + mw.domains[domain] = DomainConfig{ + AccountID: accountID, + ServiceID: serviceID, + } return nil } @@ -203,6 +225,8 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st Schemes: schemes, SessionPublicKey: pubKeyBytes, SessionExpiration: expiration, + AccountID: accountID, + ServiceID: serviceID, } return nil } @@ -216,8 +240,8 @@ func (mw *Middleware) RemoveDomain(domain string) { // validateSessionToken validates a session token, optionally checking group access via gRPC. // For OIDC tokens with a configured validator, it calls ValidateSession to check group access. // For other auth methods (PIN, password), it validates the JWT locally. -// Returns the user ID if valid, empty string if access denied, or error for invalid tokens. -func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (string, error) { +// Returns a validationResult with user ID and validity status, or error for invalid tokens. +func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) { // For OIDC with a session validator, call the gRPC service to check group access if method == auth.MethodOIDC && mw.sessionValidator != nil { resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{ @@ -226,7 +250,7 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri }) if err != nil { mw.logger.WithError(err).Error("ValidateSession gRPC call failed") - return "", fmt.Errorf("session validation failed") + return nil, fmt.Errorf("session validation failed") } if !resp.Valid { mw.logger.WithFields(log.Fields{ @@ -234,17 +258,21 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri "denied_reason": resp.DeniedReason, "user_id": resp.UserId, }).Debug("Session validation denied") - return "", nil + return &validationResult{ + UserID: resp.UserId, + Valid: false, + DeniedReason: resp.DeniedReason, + }, nil } - return resp.UserId, nil + return &validationResult{UserID: resp.UserId, Valid: true}, nil } // For non-OIDC methods or when no validator is configured, validate JWT locally userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) if err != nil { - return "", err + return nil, err } - return userID, nil + return &validationResult{UserID: userID, Valid: true}, nil } // stripSessionTokenParam returns the request URI with the session_token query diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index eac4749d5..dd6529164 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -56,7 +56,7 @@ func TestAddDomain_ValidKey(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour) + err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "") require.NoError(t, err) mw.domainsMux.RLock() @@ -73,7 +73,7 @@ func TestAddDomain_EmptyKey(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour) + err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "") require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -87,7 +87,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour) + err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "") require.Error(t, err) assert.Contains(t, err.Error(), "decode session public key") @@ -102,7 +102,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) { shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour) + err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "") require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -115,7 +115,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) { func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil) - err := mw.AddDomain("example.com", nil, "", time.Hour) + err := mw.AddDomain("example.com", nil, "", time.Hour, "", "") require.NoError(t, err, "domains with no auth schemes should not require a key") mw.domainsMux.RLock() @@ -131,8 +131,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour)) - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "")) mw.domainsMux.RLock() config := mw.domains["example.com"] @@ -148,7 +148,7 @@ func TestRemoveDomain(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) mw.RemoveDomain("example.com") @@ -172,7 +172,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) { func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil) - require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour)) + require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "")) handler := mw.Protect(newPassthroughHandler()) @@ -189,7 +189,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -210,7 +210,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -231,7 +231,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) require.NoError(t, err) @@ -261,7 +261,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) // Sign a token that expired 1 second ago. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) @@ -287,7 +287,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) // Token signed for a different domain audience. token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) @@ -314,7 +314,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "")) // Token signed with a different private key. token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) @@ -351,7 +351,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { return "", "pin" }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -395,7 +395,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { return "", "pin" }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) handler := mw.Protect(newPassthroughHandler()) @@ -431,7 +431,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { return "", "password" }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "")) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -461,7 +461,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { return "invalid-jwt-token", "" }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) handler := mw.Protect(newPassthroughHandler()) @@ -485,7 +485,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { key := base64.StdEncoding.EncodeToString(randomBytes) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour) + err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "") require.NoError(t, err, "any 32-byte key should be accepted at registration time") } @@ -494,10 +494,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")) // Attempt to overwrite with an invalid key. - err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour) + err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "") require.Error(t, err) // The original valid config should still be intact. diff --git a/proxy/server.go b/proxy/server.go index b5ec50906..4bbb8ac79 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -551,7 +551,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) } maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil { s.Logger.WithField("domain", mapping.GetDomain()).WithError(err).Error("Auth setup failed, refusing to serve domain without authentication") return }