diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go
index ab899e0bf..2790b5f20 100644
--- a/management/internals/modules/reverseproxy/domain/manager/manager.go
+++ b/management/internals/modules/reverseproxy/domain/manager/manager.go
@@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
 	if err != nil {
 		return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
 	}
-	if len(byopAddresses) > 0 {
-		return byopAddresses, nil
+	publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("get public cluster addresses: %w", err)
 	}
-	return m.proxyManager.GetActiveClusterAddresses(ctx)
+	seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
+	merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
+	for _, addr := range byopAddresses {
+		if _, ok := seen[addr]; ok {
+			continue
+		}
+		seen[addr] = struct{}{}
+		merged = append(merged, addr)
+	}
+	for _, addr := range publicAddresses {
+		if _, ok := seen[addr]; ok {
+			continue
+		}
+		seen[addr] = struct{}{}
+		merged = append(merged, addr)
+	}
+	return merged, nil
 }
 
 func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go
index fdeb0765f..5e7bbfc36 100644
--- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go
+++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go
@@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string)
 	return nil
 }
 
-func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
+func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
 	pm := &mockProxyManager{
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
 			assert.Equal(t, "acc-123", accID)
 			return []string{"byop.example.com"}, nil
 		},
 		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
-			t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
-			return nil, nil
+			return []string{"eu.proxy.netbird.io"}, nil
 		},
 	}
 
 	mgr := Manager{proxyManager: pm}
 	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
 	require.NoError(t, err)
-	assert.Equal(t, []string{"byop.example.com"}, result)
+	assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
+}
+
+func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"shared.example.com", "byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.NoError(t, err)
+	assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
 }
 
 func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
@@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
 			return nil, errors.New("db error")
 		},
-		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
-			t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
-			return nil, nil
-		},
 	}
 
 	mgr := Manager{proxyManager: pm}
@@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
 	assert.Contains(t, err.Error(), "BYOP cluster addresses")
 }
 
+func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return nil, errors.New("db error")
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "public cluster addresses")
+}
+
 func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
 	pm := &mockProxyManager{
 		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
@@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
 	assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
 }
 
+func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
+	pm := &mockProxyManager{
+		getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
+			return []string{"byop.example.com"}, nil
+		},
+		getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
+			return nil, nil
+		},
+	}
+
+	mgr := Manager{proxyManager: pm}
+	result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
+	require.NoError(t, err)
+	assert.Equal(t, []string{"byop.example.com"}, result)
+}
+
diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go
index c866d8f75..4a8598afb 100644
--- a/management/internals/modules/reverseproxy/service/manager/manager.go
+++ b/management/internals/modules/reverseproxy/service/manager/manager.go
@@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
 func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
 	customPorts := m.clusterCustomPorts(ctx, svc)
 
+	if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
+		return err
+	}
+
 	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
 		if svc.Domain != "" {
 			if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
 			return err
 		}
 
-		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
-			return err
-		}
-
 		if err := transaction.CreateService(ctx, svc); err != nil {
 			return fmt.Errorf("create service: %w", err)
 		}
@@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
 func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
 	customPorts := m.clusterCustomPorts(ctx, svc)
 
+	if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
+		return err
+	}
+
 	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
 		if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
 			return err
@@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
 			return err
 		}
 
-		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
-			return err
-		}
-
 		if err := transaction.CreateService(ctx, svc); err != nil {
 			return fmt.Errorf("create service: %w", err)
 		}
@@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
 	svcForCaps.ProxyCluster = effectiveCluster
 	customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
 
+	if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
+		return nil, err
+	}
+
+	// Validate subdomain requirement *before* the transaction: the underlying
+	// capability lookup talks to the main DB pool, and SQLite's single-connection
+	// pool would self-deadlock if this ran while the tx already held the only
+	// connection.
+	if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
+		return nil, err
+	}
+
 	var updateInfo serviceUpdateInfo
 	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
-		return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
+		return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
 	})
 
 	return &updateInfo, err
 }
@@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
 	return existing.ProxyCluster, nil
 }
 
-func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
+func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
 	existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
 	if err != nil {
 		return err
@@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	updateInfo.domainChanged = existingService.Domain != service.Domain
 	if updateInfo.domainChanged {
-		if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
+		if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
 			return err
 		}
 	} else {
 		service.ProxyCluster = existingService.ProxyCluster
 	}
 
-	if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
-		return err
-	}
-
 	m.preserveExistingAuthSecrets(service, existingService)
 	if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
 		return err
 	}
@@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	if err := m.checkPortConflict(ctx, transaction, service); err != nil {
 		return err
 	}
-	if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
-		return err
-	}
 	if err := transaction.UpdateService(ctx, service); err != nil {
 		return fmt.Errorf("update service: %w", err)
 	}
@@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
 	return nil
 }
 
-func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
+// handleDomainChange validates the new domain is free inside the transaction
+// and applies the pre-resolved cluster (computed outside the tx by
+// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
+// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
+// because the transaction already holds the only connection.
+func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
 	if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
 		return err
 	}
-
-	if m.clusterDeriver != nil {
-		newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
-		if err != nil {
-			log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
-		} else {
-			svc.ProxyCluster = newCluster
-		}
+	if effectiveCluster != "" {
+		svc.ProxyCluster = effectiveCluster
 	}
-
 	return nil
 }
 
diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go
index 4c2f0be52..893ee2168 100644
--- a/management/server/store/sql_store.go
+++ b/management/server/store/sql_store.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
 		connStr = filepath.Join(dataDir, filePath)
 	}
 
-	// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
-	if hasQuery {
-		connStr += "?" + query
-	} else if runtime.GOOS != "windows" {
+	// Compose query parameters. User-provided ?_busy_timeout (or its mattn alias
+	// ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at
+	// most that long on a lock instead of blocking the only Go-side connection.
+	// mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so
+	// the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared
+	// stays the default on non-Windows for the same reason as before.
+	parsed, _ := url.ParseQuery(query)
+	var defaults []string
+	if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
+		defaults = append(defaults, "_busy_timeout=30000")
+	}
+	if !hasQuery && runtime.GOOS != "windows" {
 		// To avoid `The process cannot access the file because it is being used by another process` on Windows
-		connStr += "?cache=shared"
+		defaults = append(defaults, "cache=shared")
+	}
+	parts := defaults
+	if hasQuery {
+		parts = append(parts, query)
+	}
+	if len(parts) > 0 {
+		connStr += "?" + strings.Join(parts, "&")
 	}
 
 	db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
@@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string)
 }
 
 func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
-	timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout)
+	timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout)
 	defer cancel()
 
 	startTime := time.Now()
diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go
index 2819265c3..7515add62 100644
--- a/management/server/store/sql_store_test.go
+++ b/management/server/store/sql_store_test.go
@@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
 	require.NoError(t, err)
 	assert.Equal(t, 0, len(remainingRecords))
 }
+
+// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies
+// that the _busy_timeout DSN parameter took effect at the driver level. Without
+// this, lock contention on the single SQLite connection waits indefinitely on
+// the Go side and can be hidden behind the 5-minute transactionTimeout.
+func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) {
+	dir := t.TempDir()
+	store, err := NewSqliteStore(context.Background(), dir, nil, true)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		_ = store.Close(context.Background())
+	})
+
+	sqlDB, err := store.db.DB()
+	require.NoError(t, err)
+	row := sqlDB.QueryRow("PRAGMA busy_timeout")
+	var busyTimeout int
+	require.NoError(t, row.Scan(&busyTimeout))
+	assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling")
+}
+
+// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator
+// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE
+// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore.
+func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) {
+	cases := []struct {
+		name     string
+		envFile  string
+		expected int
+	}{
+		{name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000},
+		{name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile)
+			dir := t.TempDir()
+			store, err := NewSqliteStore(context.Background(), dir, nil, true)
+			require.NoError(t, err)
+			t.Cleanup(func() {
+				_ = store.Close(context.Background())
+			})
+
+			sqlDB, err := store.db.DB()
+			require.NoError(t, err)
+			row := sqlDB.QueryRow("PRAGMA busy_timeout")
+			var busyTimeout int
+			require.NoError(t, row.Scan(&busyTimeout))
+			assert.Equal(t, tc.expected, busyTimeout)
+		})
+	}
+}
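
Reviewer note: the relocated validateTargetReferences/validateSubdomainRequirement calls and the handleDomainChange doc comment all guard the same failure mode. As background, here is a minimal standalone sketch of the pool-level self-deadlock (not code from this repo; the file name, driver choice, and 2-second timeout are illustrative) showing why pool-backed lookups must not run while a transaction holds the pool's only connection:

package main

import (
	"context"
	"database/sql"
	"fmt"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", "file:demo.db?_busy_timeout=30000")
	if err != nil {
		panic(err)
	}
	db.SetMaxOpenConns(1) // mirrors the management server's single-connection SQLite pool

	tx, err := db.Begin() // the transaction now owns the pool's only connection
	if err != nil {
		panic(err)
	}
	defer tx.Rollback()

	// Any query issued through db (rather than tx) must wait for a free
	// connection. _busy_timeout cannot help here: the wait happens in
	// database/sql's pool, not inside SQLite.
	queryCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	var one int
	err = db.QueryRowContext(queryCtx, "SELECT 1").Scan(&one)
	fmt.Println(err) // context deadline exceeded: the self-deadlock made visible
}

Moving the validations ahead of ExecuteInTransaction sidesteps this by finishing all pool-backed reads before the transaction claims the connection.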
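To eyeball what DSNs the reworked composition in NewSqliteStore produces, here is a condensed, self-contained restatement of the merge logic; composeDSN is a hypothetical name for this sketch, not a function in the patch:

package main

import (
	"fmt"
	"net/url"
	"strings"
)

// composeDSN mirrors the query-merge behavior: inject _busy_timeout=30000
// unless the user set _busy_timeout or _timeout, and keep cache=shared as the
// non-Windows default only when no user query is present.
func composeDSN(connStr, query string, onWindows bool) string {
	hasQuery := query != ""
	parsed, _ := url.ParseQuery(query)
	var defaults []string
	if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" {
		defaults = append(defaults, "_busy_timeout=30000")
	}
	if !hasQuery && !onWindows {
		defaults = append(defaults, "cache=shared")
	}
	parts := defaults
	if hasQuery {
		parts = append(parts, query)
	}
	if len(parts) > 0 {
		connStr += "?" + strings.Join(parts, "&")
	}
	return connStr
}

func main() {
	fmt.Println(composeDSN("store.db", "", false))              // store.db?_busy_timeout=30000&cache=shared
	fmt.Println(composeDSN("store.db", "_timeout=7000", false)) // store.db?_timeout=7000
	fmt.Println(composeDSN("store.db", "mode=ro", false))       // store.db?_busy_timeout=30000&mode=ro
}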
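The one-line ExecuteInTransaction change is easy to skim past but alters semantics: the timeout context now inherits the caller's cancellation instead of being detached via context.Background(). A minimal illustration of the difference (not repo code):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	parent, cancel := context.WithCancel(context.Background())
	cancel() // simulate the originating request going away

	// Old behavior: detached from the caller, the transaction context stays
	// live until its own (up to 5-minute) timeout fires.
	detached, c1 := context.WithTimeout(context.Background(), 5*time.Minute)
	defer c1()
	fmt.Println(detached.Err()) // <nil> — the tx would keep running

	// New behavior: cancellation propagates, so the tx aborts promptly and
	// the single SQLite connection returns to the pool.
	inherited, c2 := context.WithTimeout(parent, 5*time.Minute)
	defer c2()
	fmt.Println(inherited.Err()) // context canceled
}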