diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 28340ee37..9959f7a56 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -179,9 +179,11 @@ type StoreConfig struct { // ReverseProxyConfig contains reverse proxy settings type ReverseProxyConfig struct { - TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` - TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` - TrustedPeers []string `yaml:"trustedPeers"` + TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"` + TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"` + TrustedPeers []string `yaml:"trustedPeers"` + AccessLogRetentionDays int `yaml:"accessLogRetentionDays"` + AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"` } // DefaultConfig returns a CombinedConfig with default values @@ -645,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) { // Build reverse proxy config reverseProxy := nbconfig.ReverseProxy{ - TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount, + AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays, + AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours, } for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies { if prefix, err := netip.ParsePrefix(p); err == nil { diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go index e8d0ce763..59d7704eb 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager.go @@ -106,13 +106,23 @@ func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays in // StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs func (m *managerImpl) StartPeriodicCleanup(ctx 
context.Context, retentionDays, cleanupIntervalHours int) { - if retentionDays <= 0 { - log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative") + if retentionDays < 0 { + log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is negative") return } + if retentionDays == 0 { + retentionDays = 7 + log.WithContext(ctx).Debugf("no retention days specified for access log cleanup, defaulting to %d days", retentionDays) + } else { + log.WithContext(ctx).Debugf("access log retention period set to %d days", retentionDays) + } + if cleanupIntervalHours <= 0 { cleanupIntervalHours = 24 + log.WithContext(ctx).Debugf("no cleanup interval specified for access log cleanup, defaulting to %d hours", cleanupIntervalHours) + } else { + log.WithContext(ctx).Debugf("access log cleanup interval set to %d hours", cleanupIntervalHours) } cleanupCtx, cancel := context.WithCancel(ctx) diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go index 8fadef85f..11bf60829 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/manager_test.go @@ -121,7 +121,7 @@ func TestCleanupWithExactBoundary(t *testing.T) { } func TestStartPeriodicCleanup(t *testing.T) { - t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) { + t.Run("periodic cleanup disabled with negative retention", func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -135,7 +135,7 @@ func TestStartPeriodicCleanup(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - manager.StartPeriodicCleanup(ctx, 0, 1) + manager.StartPeriodicCleanup(ctx, -1, 1) time.Sleep(100 * time.Millisecond) diff --git a/management/internals/server/config/config.go 
b/management/internals/server/config/config.go
index 0ba393263..fb9c842b7 100644
--- a/management/internals/server/config/config.go
+++ b/management/internals/server/config/config.go
@@ -203,7 +203,7 @@ type ReverseProxy struct {
 
 	// AccessLogRetentionDays specifies the number of days to retain access logs.
 	// Logs older than this duration will be automatically deleted during cleanup.
-	// A value of 0 or negative means logs are kept indefinitely (no cleanup).
+	// A value of 0 will default to 7 days. Negative means logs are kept indefinitely (no cleanup).
 	AccessLogRetentionDays int
 
 	// AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine.
diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go
index 29555ed0c..7a51cc200 100644
--- a/management/server/migration/migration.go
+++ b/management/server/migration/migration.go
@@ -489,6 +489,132 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
 	return nil
 }
 
+// hasForeignKey reports whether a foreign key constraint exists on the given table and column.
+// A failed catalog lookup is returned as an error rather than being reported as "no constraint",
+// so the caller decides how to proceed.
+func hasForeignKey(ctx context.Context, db *gorm.DB, table, column string) (bool, error) {
+	tx := db.WithContext(ctx)
+
+	var query string
+	switch db.Name() {
+	case "postgres":
+		// Constraint names are not unique across tables in Postgres, so the join also
+		// matches on table_name, and the lookup is limited to the current schema.
+		query = `
+			SELECT COUNT(*) FROM information_schema.key_column_usage kcu
+			JOIN information_schema.table_constraints tc
+				ON tc.constraint_name = kcu.constraint_name
+				AND tc.table_schema = kcu.table_schema
+				AND tc.table_name = kcu.table_name
+			WHERE tc.constraint_type = 'FOREIGN KEY'
+				AND kcu.table_schema = current_schema()
+				AND kcu.table_name = ?
+				AND kcu.column_name = ?
+		`
+	case "mysql":
+		query = `
+			SELECT COUNT(*) FROM information_schema.key_column_usage
+			WHERE table_schema = DATABASE()
+				AND table_name = ?
+				AND column_name = ?
+				AND referenced_table_name IS NOT NULL
+		`
+	default: // sqlite
+		type fkInfo struct {
+			From string
+		}
+		var fks []fkInfo
+		if err := tx.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks).Error; err != nil {
+			return false, err
+		}
+		for _, fk := range fks {
+			if fk.From == column {
+				return true, nil
+			}
+		}
+		return false, nil
+	}
+
+	var count int64
+	if err := tx.Raw(query, table, column).Scan(&count).Error; err != nil {
+		return false, err
+	}
+	return count > 0, nil
+}
+
+// CleanupOrphanedResources deletes rows from the table of model T whose foreign key
+// column (fkColumn) is set but does not match the primary key of any row in the table
+// of model R. Rows with a NULL fkColumn are kept. Nothing is deleted when either table
+// or the column is missing, or when a foreign key constraint already exists on the column.
+func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error {
+	var model T
+	var refModel R
+
+	if !db.Migrator().HasTable(&model) {
+		log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model)
+		return nil
+	}
+
+	if !db.Migrator().HasTable(&refModel) {
+		log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel)
+		return nil
+	}
+
+	stmtT := &gorm.Statement{DB: db}
+	if err := stmtT.Parse(&model); err != nil {
+		return fmt.Errorf("parse model %T: %w", model, err)
+	}
+	childTable := stmtT.Schema.Table
+
+	stmtR := &gorm.Statement{DB: db}
+	if err := stmtR.Parse(&refModel); err != nil {
+		return fmt.Errorf("parse reference model %T: %w", refModel, err)
+	}
+	parentTable := stmtR.Schema.Table
+
+	// Match against the parent's primary key column; fall back to "id" when the
+	// schema has no single prioritized primary field.
+	parentKey := "id"
+	if pk := stmtR.Schema.PrioritizedPrimaryField; pk != nil {
+		parentKey = pk.DBName
+	}
+
+	if !db.Migrator().HasColumn(&model, fkColumn) {
+		log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable)
+		return nil
+	}
+
+	// Skip the cleanup when a foreign key constraint already exists on the column.
+	// A constraint does not strictly rule out orphans (e.g. one added as NOT VALID on
+	// Postgres or with FOREIGN_KEY_CHECKS disabled on MySQL); such rows are left untouched.
+	hasFK, err := hasForeignKey(ctx, db, childTable, fkColumn)
+	if err != nil {
+		// Best effort: the delete below only removes rows without a parent, so it is
+		// safe to run even if a constraint does exist.
+		log.WithContext(ctx).Warnf("failed to check foreign key for %s on %s, continuing with cleanup: %v", fkColumn, childTable, err)
+	}
+	if hasFK {
+		log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable)
+		return nil
+	}
+
+	// NOT EXISTS with an explicit NULL guard: unlike NOT IN, the outcome for NULL
+	// foreign keys does not depend on whether the parent table is empty.
+	query := fmt.Sprintf(
+		"DELETE FROM %[1]s WHERE %[2]s IS NOT NULL AND NOT EXISTS (SELECT 1 FROM %[3]s AS parent_ref WHERE parent_ref.%[4]s = %[1]s.%[2]s)",
+		childTable, fkColumn, parentTable, parentKey,
+	)
+	result := db.WithContext(ctx).Exec(query)
+	if result.Error != nil {
+		return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error)
+	}
+
+	log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s",
+		result.RowsAffected, childTable, fkColumn, parentTable)
+
+	return nil
+}
+
 func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
 	if !db.Migrator().HasTable("peers") {
 		log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")
diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go
index c1be8a3a3..5e00976c2 100644
--- a/management/server/migration/migration_test.go
+++ b/management/server/migration/migration_test.go
@@ -441,3 +441,222 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
 	err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
 	require.NoError(t, err, "Should not fail when table does not exist")
 }
+
+type testParent struct {
+	ID string `gorm:"primaryKey"`
+}
+
+func (testParent) TableName() string {
+	return "test_parents"
+}
+
+type testChild struct {
+	ID       string `gorm:"primaryKey"`
+	ParentID string
+}
+
+func (testChild) TableName() string {
+	return "test_children"
+}
+
+type testChildWithFK struct {
+	ID       string      `gorm:"primaryKey"`
+	ParentID string      `gorm:"index"`
+	Parent   *testParent `gorm:"foreignKey:ParentID"`
+}
+
+func (testChildWithFK) TableName() string {
+	return "test_children"
+}
+
+func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB {
+	t.Helper()
+	db := setupDatabase(t)
+	for _, m := range models {
+		_ = db.Migrator().DropTable(m)
+	}
+	err := db.AutoMigrate(models...)
+	require.NoError(t, err, "Failed to auto-migrate tables")
+	return db
+}
+
+func TestCleanupOrphanedResources_NoChildTable(t *testing.T) {
+	db := setupDatabase(t)
+	_ = db.Migrator().DropTable(&testChild{})
+	_ = db.Migrator().DropTable(&testParent{})
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err, "Should not fail when child table does not exist")
+}
+
+func TestCleanupOrphanedResources_NoParentTable(t *testing.T) {
+	db := setupDatabase(t)
+	_ = db.Migrator().DropTable(&testParent{})
+	_ = db.Migrator().DropTable(&testChild{})
+
+	err := db.AutoMigrate(&testChild{})
+	require.NoError(t, err)
+
+	err = migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err, "Should not fail when parent table does not exist")
+}
+
+func TestCleanupOrphanedResources_EmptyTables(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err, "Should not fail on empty tables")
+
+	var count int64
+	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
+	assert.Equal(t, int64(0), count)
+}
+
+func TestCleanupOrphanedResources_NoOrphans(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
+	require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
+	require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
+	require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error)
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err)
+
+	var count int64
+	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
+	assert.Equal(t, int64(2), count, "All children should remain when no orphans")
+}
+
+func TestCleanupOrphanedResources_AllOrphans(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c1", "gone1").Error)
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone2").Error)
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c3", "gone3").Error)
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err)
+
+	var count int64
+	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
+	assert.Equal(t, int64(0), count, "All orphaned children should be deleted")
+}
+
+func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
+	require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
+
+	require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
+	require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error)
+	require.NoError(t, db.Create(&testChild{ID: "c3", ParentID: "p1"}).Error)
+
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c4", "gone1").Error)
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c5", "gone2").Error)
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err)
+
+	var remaining []testChild
+	require.NoError(t, db.Order("id").Find(&remaining).Error)
+
+	assert.Len(t, remaining, 3, "Only valid children should remain")
+	assert.Equal(t, "c1", remaining[0].ID)
+	assert.Equal(t, "c2", remaining[1].ID)
+	assert.Equal(t, "c3", remaining[2].ID)
+}
+
+func TestCleanupOrphanedResources_Idempotent(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
+	require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error)
+
+	ctx := context.Background()
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
+	require.NoError(t, err)
+
+	var count int64
+	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
+	assert.Equal(t, int64(1), count)
+
+	err = migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
+	require.NoError(t, err)
+
+	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
+	assert.Equal(t, int64(1), count, "Count should remain the same after second run")
+}
+
+func TestCleanupOrphanedResources_KeepsNullForeignKey(t *testing.T) {
+	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
+
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, NULL)", "c1").Error)
+	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error)
+
+	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err)
+
+	var total, withoutParent int64
+	require.NoError(t, db.Model(&testChild{}).Count(&total).Error)
+	require.NoError(t, db.Model(&testChild{}).Where("parent_id IS NULL").Count(&withoutParent).Error)
+	assert.Equal(t, int64(1), total, "Only the orphaned row should be deleted")
+	assert.Equal(t, int64(1), withoutParent, "Rows with a NULL foreign key are not orphans and should remain")
+}
+
+func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) {
+	engine := os.Getenv("NETBIRD_STORE_ENGINE")
+	if engine != "postgres" && engine != "mysql" {
+		t.Skip("FK constraint early-exit test requires postgres or mysql")
+	}
+
+	db := setupDatabase(t)
+	_ = db.Migrator().DropTable(&testChildWithFK{})
+	_ = db.Migrator().DropTable(&testParent{})
+
+	err := db.AutoMigrate(&testParent{}, &testChildWithFK{})
+	require.NoError(t, err)
+
+	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
+	require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
+	require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error)
+	require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error)
+
+	switch engine {
+	case "postgres":
+		require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error)
+		require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
+		require.NoError(t, db.Exec(
+			"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
+				"FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID",
+		).Error)
+	case "mysql":
+		// FOREIGN_KEY_CHECKS is a session variable, so all statements that depend on
+		// it run on one pinned connection instead of whichever the pool hands out.
+		require.NoError(t, db.Connection(func(tx *gorm.DB) error {
+			for _, stmt := range []string{
+				"SET FOREIGN_KEY_CHECKS = 0",
+				"ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent",
+				"DELETE FROM test_parents WHERE id = 'p2'",
+				"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent " +
+					"FOREIGN KEY (parent_id) REFERENCES test_parents(id)",
+				"SET FOREIGN_KEY_CHECKS = 1",
+			} {
+				if err := tx.Exec(stmt).Error; err != nil {
+					return err
+				}
+			}
+			return nil
+		}))
+	}
+
+	err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id")
+	require.NoError(t, err)
+
+	var count int64
+	require.NoError(t, db.Model(&testChildWithFK{}).Count(&count).Error)
+	assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists")
+}
diff --git a/management/server/store/store.go b/management/server/store/store.go
index 9bd45618d..926b00415 100644
--- a/management/server/store/store.go
+++ b/management/server/store/store.go
@@ -449,6 +449,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
 		func(db *gorm.DB) error {
 			return migration.RemoveDuplicatePeerKeys(ctx, db)
 		},
+		func(db *gorm.DB) error {
+			return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id")
+		},
+		func(db *gorm.DB) error {
+			return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id")
+		},
 	}
 }
 
diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go
index ed1b63435..b10b05617 100644
--- a/shared/relay/client/client.go
+++ b/shared/relay/client/client.go
@@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
 	dialers := c.getDialers()
 
 	rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
- conn, err := rd.Dial() + conn, err := rd.Dial(ctx) if err != nil { return nil, err } diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 0550fc63e..34359d17e 100644 --- a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } -func (r *RaceDial) Dial() (net.Conn, error) { +func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) - abortCtx, abort := context.WithCancel(context.Background()) + abortCtx, abort := context.WithCancel(ctx) defer abort() for _, dfn := range r.dialerFns { diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index d216ec5e7..aa18df578 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) { serverURL := "test.server.com" rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error with empty dialers, got nil") } @@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) { } rd := 
NewRaceDial(logger, 3*time.Second, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) }