mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
[proxy] feature: bring your own proxy
This commit is contained in:
@@ -4464,6 +4464,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var tokens []*types.ProxyAccessToken
|
||||
result := tx.Where("account_id = ?", accountID).Find(&tokens)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return token.IsValid(), nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var token types.ProxyAccessToken
|
||||
result := tx.Take(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||
@@ -5370,7 +5411,25 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
return status.Errorf(status.Internal, "failed to save proxy: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisconnectProxy updates only the status, disconnected_at, and last_seen fields
|
||||
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ?", proxyID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": proxy.StatusDisconnected,
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5379,7 +5438,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
|
||||
Update("last_seen", time.Now())
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -5395,7 +5454,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
@@ -5407,6 +5466,66 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
var p proxy.Proxy
|
||||
result := s.db.WithContext(ctx).Where("account_id = ?", accountID).Take(&p)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy not found for account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.WithContext(ctx).Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
var count int64
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "delete proxy: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "proxy not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
Reference in New Issue
Block a user