mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 19:56:46 +00:00
Merge branch 'main' into proto-ipv6-overlay
This commit is contained in:
@@ -1018,10 +1018,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
var total, validated int64
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
||||
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
||||
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return total, validated, nil
|
||||
@@ -4457,7 +4457,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
||||
|
||||
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
|
||||
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4476,7 +4476,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
|
||||
|
||||
// GetAllProxyAccessTokens retrieves all proxy access tokens.
|
||||
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4492,7 +4492,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
|
||||
|
||||
// SaveProxyAccessToken saves a proxy access token to the database.
|
||||
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
|
||||
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
|
||||
if result := s.db.Create(token); result.Error != nil {
|
||||
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -4500,7 +4500,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
|
||||
|
||||
// RevokeProxyAccessToken revokes a proxy access token by its ID.
|
||||
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
|
||||
}
|
||||
@@ -4514,7 +4514,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).
|
||||
Where(idQueryCondition, tokenID).
|
||||
Update("last_used", time.Now().UTC())
|
||||
if result.Error != nil {
|
||||
@@ -5204,7 +5204,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
|
||||
|
||||
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
|
||||
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -5220,7 +5220,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
|
||||
|
||||
// GetServicesByCluster returns all services for the given proxy cluster.
|
||||
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -5330,7 +5330,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
var logs []*accesslogs.AccessLogEntry
|
||||
var totalCount int64
|
||||
|
||||
baseQuery := s.db.WithContext(ctx).
|
||||
baseQuery := s.db.
|
||||
Model(&accesslogs.AccessLogEntry{}).
|
||||
Where(accountIDCondition, accountID)
|
||||
|
||||
@@ -5341,7 +5341,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
query := s.db.
|
||||
Where(accountIDCondition, accountID)
|
||||
|
||||
query = s.applyAccessLogFilters(query, filter)
|
||||
@@ -5378,7 +5378,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
|
||||
// DeleteOldAccessLogs deletes all access logs older than the specified time
|
||||
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Where("timestamp < ?", olderThan).
|
||||
Delete(&accesslogs.AccessLogEntry{})
|
||||
|
||||
@@ -5467,7 +5467,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
result := s.db.Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
@@ -5479,7 +5479,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", now)
|
||||
@@ -5498,7 +5498,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
|
||||
if err := s.db.Save(p).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
||||
}
|
||||
@@ -5511,7 +5511,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
@@ -5550,6 +5550,7 @@ const proxyActiveThreshold = 2 * time.Minute
|
||||
var validCapabilityColumns = map[string]struct{}{
|
||||
"supports_custom_ports": {},
|
||||
"require_subdomain": {},
|
||||
"supports_crowdsec": {},
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
@@ -5564,6 +5565,59 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
|
||||
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
||||
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
||||
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
||||
// requires unanimous support: a single unconfigured proxy would let requests
|
||||
// bypass reputation checks.
|
||||
func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec")
|
||||
}
|
||||
|
||||
// getClusterUnanimousCapability returns an aggregated boolean capability
|
||||
// requiring all active proxies in the cluster to report true.
|
||||
func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool {
|
||||
if _, ok := validCapabilityColumns[column]; !ok {
|
||||
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Total int64
|
||||
Reported int64
|
||||
AllTrue bool
|
||||
}
|
||||
|
||||
// All active proxies must have reported the capability (no NULLs) and all
|
||||
// must report true. A single unreported or false proxy means the cluster
|
||||
// does not unanimously support the capability.
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(*) AS total, "+
|
||||
"COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+
|
||||
"COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true").
|
||||
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
||||
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if result.Total == 0 || result.Reported == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If any proxy has not reported (NULL), we can't confirm unanimous support.
|
||||
if result.Reported < result.Total {
|
||||
v := false
|
||||
return &v
|
||||
}
|
||||
|
||||
return &result.AllTrue
|
||||
}
|
||||
|
||||
// getClusterCapability returns an aggregated boolean capability for the given
|
||||
// cluster. It checks active (connected, recently seen) proxies and returns:
|
||||
// - *true if any proxy in the cluster has the capability set to true,
|
||||
@@ -5580,7 +5634,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
AnyTrue bool
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
err := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
|
||||
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
|
||||
@@ -5604,7 +5658,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user