[management] add context cancel monitoring (#5879)

This commit is contained in:
Pascal Fischer
2026-04-14 12:49:18 +02:00
committed by GitHub
parent 7f666b8022
commit c5623307cc
2 changed files with 29 additions and 18 deletions

View File

@@ -1017,10 +1017,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
// GetCustomDomainsCounts returns the total and validated custom domain counts. // GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) { func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64 var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil { if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err return 0, 0, err
} }
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil { if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err return 0, 0, err
} }
return total, validated, nil return total, validated, nil
@@ -4442,7 +4442,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value. // GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) { func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx) tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
@@ -4461,7 +4461,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
// GetAllProxyAccessTokens retrieves all proxy access tokens. // GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) { func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx) tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
@@ -4477,7 +4477,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
// SaveProxyAccessToken saves a proxy access token to the database. // SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error { func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil { if result := s.db.Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error) return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
} }
return nil return nil
@@ -4485,7 +4485,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
// RevokeProxyAccessToken revokes a proxy access token by its ID. // RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true) result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error) return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
} }
@@ -4499,7 +4499,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}). result := s.db.Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID). Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC()) Update("last_used", time.Now().UTC())
if result.Error != nil { if result.Error != nil {
@@ -5168,7 +5168,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port. // GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) { func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx) tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
@@ -5184,7 +5184,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
// GetServicesByCluster returns all services for the given proxy cluster. // GetServicesByCluster returns all services for the given proxy cluster.
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) { func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx) tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
@@ -5294,7 +5294,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
var logs []*accesslogs.AccessLogEntry var logs []*accesslogs.AccessLogEntry
var totalCount int64 var totalCount int64
baseQuery := s.db.WithContext(ctx). baseQuery := s.db.
Model(&accesslogs.AccessLogEntry{}). Model(&accesslogs.AccessLogEntry{}).
Where(accountIDCondition, accountID) Where(accountIDCondition, accountID)
@@ -5305,7 +5305,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return nil, 0, status.Errorf(status.Internal, "failed to count access logs") return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
} }
query := s.db.WithContext(ctx). query := s.db.
Where(accountIDCondition, accountID) Where(accountIDCondition, accountID)
query = s.applyAccessLogFilters(query, filter) query = s.applyAccessLogFilters(query, filter)
@@ -5342,7 +5342,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
// DeleteOldAccessLogs deletes all access logs older than the specified time // DeleteOldAccessLogs deletes all access logs older than the specified time
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) { func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
result := s.db.WithContext(ctx). result := s.db.
Where("timestamp < ?", olderThan). Where("timestamp < ?", olderThan).
Delete(&accesslogs.AccessLogEntry{}) Delete(&accesslogs.AccessLogEntry{})
@@ -5431,7 +5431,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
// SaveProxy saves or updates a proxy in the database // SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p) result := s.db.Save(p)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy") return status.Errorf(status.Internal, "failed to save proxy")
@@ -5443,7 +5443,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now() now := time.Now()
result := s.db.WithContext(ctx). result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected"). Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now) Update("last_seen", now)
@@ -5462,7 +5462,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
ConnectedAt: &now, ConnectedAt: &now,
Status: "connected", Status: "connected",
} }
if err := s.db.WithContext(ctx).Save(p).Error; err != nil { if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat") return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
} }
@@ -5475,7 +5475,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string var addresses []string
result := s.db.WithContext(ctx). result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address"). Distinct("cluster_address").
@@ -5598,7 +5598,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
AnyTrue bool AnyTrue bool
} }
err := s.db.WithContext(ctx). err := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+ Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true"). "COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
@@ -5622,7 +5622,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration) cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx). result := s.db.
Where("last_seen < ?", cutoffTime). Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{}) Delete(&proxy.Proxy{})

View File

@@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
w := WrapResponseWriter(rw) w := WrapResponseWriter(rw)
handlerDone := make(chan struct{})
context.AfterFunc(ctx, func() {
select {
case <-handlerDone:
default:
log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)",
r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx))
}
})
h.ServeHTTP(w, r.WithContext(ctx)) h.ServeHTTP(w, r.WithContext(ctx))
close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil { if err == nil {