[proxy] feature: bring your own proxy

This commit is contained in:
crn4
2026-03-17 13:17:50 +01:00
parent 5585adce18
commit 26ba03f08e
32 changed files with 2697 additions and 94 deletions

View File

@@ -4464,6 +4464,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
return nil
}
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Where("account_id = ?", accountID).Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
}
return tokens, nil
}
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
if err != nil {
return false, err
}
return token.IsValid(), nil
}
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
}
return &token, nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
@@ -5370,7 +5411,25 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
return status.Errorf(status.Internal, "failed to save proxy: %v", result.Error)
}
return nil
}
// DisconnectProxy updates only the status, disconnected_at, and last_seen fields
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
now := time.Now()
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ?", proxyID).
Updates(map[string]interface{}{
"status": proxy.StatusDisconnected,
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
return nil
}
@@ -5379,7 +5438,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
Update("last_seen", time.Now())
if result.Error != nil {
@@ -5395,7 +5454,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
@@ -5407,6 +5466,66 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
return addresses, nil
}
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
}
return addresses, nil
}
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
var p proxy.Proxy
result := s.db.WithContext(ctx).Where("account_id = ?", accountID).Take(&p)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy not found for account")
}
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
}
return &p, nil
}
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
var count int64
result := s.db.WithContext(ctx).Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
}
return count, nil
}
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
Count(&count)
if result.Error != nil {
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
}
return count > 0, nil
}
func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{})
if result.Error != nil {
return status.Errorf(status.Internal, "delete proxy: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy not found")
}
return nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)