mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
fix: capture account/service/user IDs in access logs for auth requests
- Add accountID and serviceID to auth middleware DomainConfig - Set account/service IDs in CapturedData when domain is matched - Update AddDomain to accept accountID and serviceID parameters - Skip access logging for internal proxy assets (/__netbird__/*) - Return validationResult struct from validateSessionToken to preserve user ID even when access is denied - Capture user ID and auth method in access logs for denied requests
This commit is contained in:
@@ -3,15 +3,23 @@ package accesslog
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip logging for internal proxy assets (CSS, JS, etc.)
|
||||
if strings.HasPrefix(r.URL.Path, web.PathPrefix+"/") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate request ID early so it can be used by error pages and log correlation.
|
||||
requestID := xid.New().String()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -44,6 +45,14 @@ type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
AccountID string
|
||||
ServiceID string
|
||||
}
|
||||
|
||||
type validationResult struct {
|
||||
UserID string
|
||||
Valid bool
|
||||
DeniedReason string
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
@@ -94,6 +103,12 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
|
||||
// Check for error from OAuth callback (e.g., access denied)
|
||||
if errCode := r.URL.Query().Get("error"); errCode != "" {
|
||||
var requestID string
|
||||
@@ -126,7 +141,7 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData := scheme.Authenticate(r)
|
||||
if token != "" {
|
||||
userID, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
@@ -134,10 +149,12 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if userID == "" {
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
@@ -161,6 +178,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
@@ -181,11 +200,14 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration) error {
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{}
|
||||
mw.domains[domain] = DomainConfig{
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -203,6 +225,8 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st
|
||||
Schemes: schemes,
|
||||
SessionPublicKey: pubKeyBytes,
|
||||
SessionExpiration: expiration,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -216,8 +240,8 @@ func (mw *Middleware) RemoveDomain(domain string) {
|
||||
// validateSessionToken validates a session token, optionally checking group access via gRPC.
|
||||
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
|
||||
// For other auth methods (PIN, password), it validates the JWT locally.
|
||||
// Returns the user ID if valid, empty string if access denied, or error for invalid tokens.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (string, error) {
|
||||
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
|
||||
// For OIDC with a session validator, call the gRPC service to check group access
|
||||
if method == auth.MethodOIDC && mw.sessionValidator != nil {
|
||||
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
|
||||
@@ -226,7 +250,7 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
})
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
|
||||
return "", fmt.Errorf("session validation failed")
|
||||
return nil, fmt.Errorf("session validation failed")
|
||||
}
|
||||
if !resp.Valid {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
@@ -234,17 +258,21 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri
|
||||
"denied_reason": resp.DeniedReason,
|
||||
"user_id": resp.UserId,
|
||||
}).Debug("Session validation denied")
|
||||
return "", nil
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
Valid: false,
|
||||
DeniedReason: resp.DeniedReason,
|
||||
}, nil
|
||||
}
|
||||
return resp.UserId, nil
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
}
|
||||
|
||||
// For non-OIDC methods or when no validator is configured, validate JWT locally
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return userID, nil
|
||||
return &validationResult{UserID: userID, Valid: true}, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns the request URI with the session_token query
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestAddDomain_ValidKey(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -73,7 +73,7 @@ func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -87,7 +87,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
@@ -102,7 +102,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
@@ -115,7 +115,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour)
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
@@ -131,8 +131,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
@@ -148,7 +148,7 @@ func TestRemoveDomain(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
@@ -172,7 +172,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -189,7 +189,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -210,7 +210,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -231,7 +231,7 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
@@ -261,7 +261,7 @@ func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
@@ -287,7 +287,7 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
@@ -314,7 +314,7 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
@@ -351,7 +351,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
return "", "pin"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -395,7 +395,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
return "", "pin"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -431,7 +431,7 @@ func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
return "", "password"
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
@@ -461,7 +461,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
return "invalid-jwt-token", ""
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour)
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
@@ -494,10 +494,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour)
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
|
||||
@@ -551,7 +551,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
}
|
||||
|
||||
maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge); err != nil {
|
||||
if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, mapping.GetAccountId(), mapping.GetId()); err != nil {
|
||||
s.Logger.WithField("domain", mapping.GetDomain()).WithError(err).Error("Auth setup failed, refusing to serve domain without authentication")
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user