[management] fix offline statuses for public proxy clusters (#6133)

Vlad
2026-05-14 13:27:50 +02:00
committed by GitHub
parent ab2a8794e7
commit 77b479286e
5 changed files with 177 additions and 45 deletions

View File

@@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
 	if err != nil {
 		return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
 	}
-	if len(byopAddresses) > 0 {
-		return byopAddresses, nil
+	publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("get public cluster addresses: %w", err)
 	}
-	return m.proxyManager.GetActiveClusterAddresses(ctx)
+	seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
+	merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
+	for _, addr := range byopAddresses {
+		if _, ok := seen[addr]; ok {
+			continue
+		}
+		seen[addr] = struct{}{}
+		merged = append(merged, addr)
+	}
+	for _, addr := range publicAddresses {
+		if _, ok := seen[addr]; ok {
+			continue
+		}
+		seen[addr] = struct{}{}
+		merged = append(merged, addr)
+	}
+	return merged, nil
 }

 func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
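
Note: the block above is an order-preserving set union: BYOP addresses keep priority at the front of the allow list, shared public clusters follow, and an address appearing in both lists is emitted once. A minimal standalone sketch of the same technique (mergeUnique is a hypothetical helper, not part of this commit):

package main

import "fmt"

// mergeUnique unions the given lists while preserving first-seen order,
// mirroring the BYOP-before-public merge in getClusterAllowList.
func mergeUnique(lists ...[]string) []string {
	seen := make(map[string]struct{})
	var merged []string
	for _, list := range lists {
		for _, addr := range list {
			if _, ok := seen[addr]; ok {
				continue
			}
			seen[addr] = struct{}{}
			merged = append(merged, addr)
		}
	}
	return merged
}

func main() {
	byop := []string{"shared.example.com", "byop.example.com"}
	public := []string{"shared.example.com", "eu.proxy.netbird.io"}
	fmt.Println(mergeUnique(byop, public))
	// Output: [shared.example.com byop.example.com eu.proxy.netbird.io]
}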

View File

@@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
 	return nil
 }

-func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
+func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
 	pm := &mockProxyManager{
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
 			assert.Equal(t, "acc-123", accID)
 			return []string{"byop.example.com"}, nil
 		},
 		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
-			t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
-			return nil, nil
+			return []string{"eu.proxy.netbird.io"}, nil
 		},
 	}

 	mgr := Manager{proxyManager: pm}
 	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
 	require.NoError(t, err)
-	assert.Equal(t, []string{"byop.example.com"}, result)
+	assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
 }
+
+func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"shared.example.com", "byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.NoError(t, err)
+	assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
+}

 func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
@@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
 			return nil, errors.New("db error")
 		},
-		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
-			t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
-			return nil, nil
-		},
 	}

 	mgr := Manager{proxyManager: pm}
@@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
 	assert.Contains(t, err.Error(), "BYOP cluster addresses")
 }
+
+func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return nil, errors.New("db error")
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "public cluster addresses")
+}

 func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
 	pm := &mockProxyManager{
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
@@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
 	require.NoError(t, err)
 	assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
 }
+
+func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return nil, nil
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.NoError(t, err)
+	assert.Equal(t, []string{"byop.example.com"}, result)
+}
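
The mockProxyManager definition sits outside this diff; the tests rely on the function-field mock pattern, where each interface method delegates to an assignable closure. A sketch of that pattern, assuming the ProxyManager dependency exposes exactly the two methods getClusterAllowList calls (interface and method names inferred, not confirmed by the diff):

package management // hypothetical package name, for illustration

import "context"

type proxyManager interface {
	GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
	GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}

// mockProxyManager holds one func field per interface method; each test
// configures behavior per case by assigning closures, as the tests above do.
type mockProxyManager struct {
	getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
	getActiveClusterAddressesFunc           func(ctx context.Context) ([]string, error)
}

func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
	return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}

func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
	return m.getActiveClusterAddressesFunc(ctx)
}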

View File

@@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
 func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
 	customPorts := m.clusterCustomPorts(ctx, svc)

+	if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
+		return err
+	}
+
 	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
 		if svc.Domain != "" {
 			if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
 			return err
 		}

-		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
-			return err
-		}
-
 		if err := transaction.CreateService(ctx, svc); err != nil {
 			return fmt.Errorf("create service: %w", err)
 		}
@@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
 func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
 	customPorts := m.clusterCustomPorts(ctx, svc)

+	if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
+		return err
+	}
+
 	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
 		if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
 			return err
@@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
 			return err
 		}

-		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
-			return err
-		}
-
 		if err := transaction.CreateService(ctx, svc); err != nil {
 			return fmt.Errorf("create service: %w", err)
 		}
@@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
 	svcForCaps.ProxyCluster = effectiveCluster
 	customPorts := m.clusterCustomPorts(ctx, &svcForCaps)

+	if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
+		return nil, err
+	}
+
+	// Validate subdomain requirement *before* the transaction: the underlying
+	// capability lookup talks to the main DB pool, and SQLite's single-connection
+	// pool would self-deadlock if this ran while the tx already held the only
+	// connection.
+	if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
+		return nil, err
+	}
+
 	var updateInfo serviceUpdateInfo
 	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
-		return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
+		return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
 	})

 	return &updateInfo, err
@@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
 	return existing.ProxyCluster, nil
 }

-func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
+func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
 	existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
 	if err != nil {
 		return err
@@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	updateInfo.domainChanged = existingService.Domain != service.Domain
 	if updateInfo.domainChanged {
-		if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
+		if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
 			return err
 		}
 	} else {
 		service.ProxyCluster = existingService.ProxyCluster
 	}

-	if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
-		return err
-	}
-
 	m.preserveExistingAuthSecrets(service, existingService)

 	if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
 		return err
@@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	if err := m.checkPortConflict(ctx, transaction, service); err != nil {
 		return err
 	}
-	if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
-		return err
-	}
 	if err := transaction.UpdateService(ctx, service); err != nil {
 		return fmt.Errorf("update service: %w", err)
 	}
@@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	return nil
 }

-func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
+// handleDomainChange validates the new domain is free inside the transaction
+// and applies the pre-resolved cluster (computed outside the tx by
+// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
+// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
+// because the transaction already holds the only connection.
+func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
 	if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
 		return err
 	}
-
-	if m.clusterDeriver != nil {
-		newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
-		if err != nil {
-			log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
-		} else {
-			svc.ProxyCluster = newCluster
-		}
-	}
-
+	if effectiveCluster != "" {
+		svc.ProxyCluster = effectiveCluster
+	}
 	return nil
 }
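Both comments in this file describe the same failure mode: with a pool capped at one connection, any query issued on the pool while a transaction holds that connection blocks until a timeout fires. A minimal, self-contained reproduction of that self-deadlock with database/sql (illustrative only; not NetBird code, and the demo.db path is arbitrary):

package main

import (
	"database/sql"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", "file:demo.db?cache=shared")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	db.SetMaxOpenConns(1) // mirror the single-connection SQLite pool

	tx, err := db.Begin() // the transaction now owns the pool's only connection
	if err != nil {
		log.Fatal(err)
	}
	defer tx.Rollback()

	// Any query on the pool itself now waits for a free connection that the
	// open transaction will never release until it finishes: this call hangs.
	var n int
	_ = db.QueryRow("SELECT 1").Scan(&n) // blocks forever (self-deadlock)
}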

View File

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
 		connStr = filepath.Join(dataDir, filePath)
 	}

-	// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
-	if hasQuery {
-		connStr += "?" + query
-	} else if runtime.GOOS != "windows" {
+	// Compose query parameters. User-provided ?_busy_timeout (or its mattn alias
+	// ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at
+	// most that long on a lock instead of blocking the only Go-side connection.
+	// mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so
+	// the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared
+	// stays the default on non-Windows for the same reason as before.
+	parsed, _ := url.ParseQuery(query)
+	var defaults []string
+	if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
+		defaults = append(defaults, "_busy_timeout=30000")
+	}
+	if !hasQuery && runtime.GOOS != "windows" {
 		// To avoid `The process cannot access the file because it is being used by another process` on Windows
-		connStr += "?cache=shared"
+		defaults = append(defaults, "cache=shared")
+	}
+	parts := defaults
+	if hasQuery {
+		parts = append(parts, query)
+	}
+	if len(parts) > 0 {
+		connStr += "?" + strings.Join(parts, "&")
 	}

 	db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
@@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
 }

 func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
-	timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
+	timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout)
 	defer cancel()

 	startTime := time.Now()
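
Two behavioral notes on this file: ExecuteInTransaction now derives its timeout context from the caller's ctx rather than context.Background(), so request cancellation propagates into in-flight transactions; and the DSN merge above should produce the following connection strings on non-Windows (illustrative pairs derived from the logic, not taken from the commit):

NB_STORE_ENGINE_SQLITE_FILE value     resulting connection string
store.db                              store.db?_busy_timeout=30000&cache=shared
store.db?_busy_timeout=5000           store.db?_busy_timeout=5000
store.db?_timeout=7000                store.db?_timeout=7000
store.db?mode=rwc                     store.db?_busy_timeout=30000&mode=rwc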

View File

@@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
 	require.NoError(t, err)
 	assert.Equal(t, 0, len(remainingRecords))
 }
+
+// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies
+// that the _busy_timeout DSN parameter took effect at the driver level. Without
+// this, lock contention on the single SQLite connection waits indefinitely on
+// the Go side and can be hidden behind the 5-minute transactionTimeout.
+func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) {
+	dir := t.TempDir()
+	store, err := NewSqliteStore(context.Background(), dir, nil, true)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		_ = store.Close(context.Background())
+	})
+
+	sqlDB, err := store.db.DB()
+	require.NoError(t, err)
+
+	row := sqlDB.QueryRow("PRAGMA busy_timeout")
+	var busyTimeout int
+	require.NoError(t, row.Scan(&busyTimeout))
+	assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling")
+}
+
+// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator
+// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE
+// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore.
+func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) {
+	cases := []struct {
+		name     string
+		envFile  string
+		expected int
+	}{
+		{name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000},
+		{name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile)
+			dir := t.TempDir()
+			store, err := NewSqliteStore(context.Background(), dir, nil, true)
+			require.NoError(t, err)
+			t.Cleanup(func() {
+				_ = store.Close(context.Background())
+			})
+
+			sqlDB, err := store.db.DB()
+			require.NoError(t, err)
+			row := sqlDB.QueryRow("PRAGMA busy_timeout")
+			var busyTimeout int
+			require.NoError(t, row.Scan(&busyTimeout))
+			assert.Equal(t, tc.expected, busyTimeout)
+		})
+	}
+}
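
The tests assert through PRAGMA busy_timeout rather than by inspecting the DSN string because only the driver-level value proves the setting survives connection pooling: a one-off db.Exec("PRAGMA busy_timeout=30000") configures only whichever connection ran it, while a DSN parameter is re-applied by mattn/go-sqlite3 on every fresh connection. A small standalone sketch of that property (illustrative, not from the commit; the store.db path is arbitrary):

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	// The DSN pragma applies to every connection the pool opens, so it holds
	// even after ConnMaxIdleTime/ConnMaxLifetime recycle the underlying one.
	db, err := sql.Open("sqlite3", "file:store.db?_busy_timeout=30000")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var ms int
	if err := db.QueryRow("PRAGMA busy_timeout").Scan(&ms); err != nil {
		log.Fatal(err)
	}
	fmt.Println(ms) // prints 30000 regardless of which pooled connection answers
}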