mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
638 lines
19 KiB
Go
638 lines
19 KiB
Go
package migration
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
b64 "encoding/base64"
|
|
"encoding/gob"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
func GetColumnName(db *gorm.DB, column string) string {
|
|
if db.Name() == "mysql" {
|
|
return fmt.Sprintf("`%s`", column)
|
|
}
|
|
return column
|
|
}
|
|
|
|
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
|
|
// T is the type of the model that contains the field to be migrated.
|
|
// S is the type of the field to be migrated.
|
|
func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error {
|
|
orgColumnName := fieldName
|
|
oldColumnName := GetColumnName(db, orgColumnName)
|
|
newColumnName := fieldName + "_tmp"
|
|
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
if !db.Migrator().HasColumn(&model, fieldName) {
|
|
log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
err := stmt.Parse(model)
|
|
if err != nil {
|
|
return fmt.Errorf("parse model: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
|
|
var sqliteItem sql.NullString
|
|
if err := db.Model(model).Select(oldColumnName).First(&sqliteItem).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
log.WithContext(ctx).Debugf("No records in table %s, no migration needed", tableName)
|
|
return nil
|
|
}
|
|
return fmt.Errorf("fetch first record: %w", err)
|
|
}
|
|
|
|
item := sqliteItem.String
|
|
|
|
var js json.RawMessage
|
|
var syntaxError *json.SyntaxError
|
|
err = json.Unmarshal([]byte(item), &js)
|
|
// if the item is JSON parsable or an empty string it can not be gob encoded
|
|
if err == nil || !errors.As(err, &syntaxError) || item == "" {
|
|
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
|
|
return nil
|
|
}
|
|
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
|
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
|
}
|
|
|
|
var rows []map[string]any
|
|
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
|
return fmt.Errorf("find rows: %w", err)
|
|
}
|
|
|
|
for _, row := range rows {
|
|
var field S
|
|
|
|
str, ok := row[orgColumnName].(string)
|
|
if !ok {
|
|
return fmt.Errorf("type assertion failed")
|
|
}
|
|
reader := strings.NewReader(str)
|
|
|
|
if err := gob.NewDecoder(reader).Decode(&field); err != nil {
|
|
return fmt.Errorf("gob decode error: %w", err)
|
|
}
|
|
jsonValue, err := json.Marshal(field)
|
|
if err != nil {
|
|
return fmt.Errorf("re-encode to JSON: %w", err)
|
|
}
|
|
|
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
|
return fmt.Errorf("update row: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
|
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
|
}
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
|
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
|
|
// T is the type of the model that contains the field to be migrated.
|
|
func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error {
|
|
orgColumnName := fieldName
|
|
oldColumnName := GetColumnName(db, orgColumnName)
|
|
newColumnName := fieldName + "_tmp"
|
|
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.Printf("Table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
err := stmt.Parse(&model)
|
|
if err != nil {
|
|
return fmt.Errorf("parse model: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
|
|
var item sql.NullString
|
|
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
log.Printf("No records in table %s, no migration needed", tableName)
|
|
return nil
|
|
}
|
|
return fmt.Errorf("fetch first record: %w", err)
|
|
}
|
|
|
|
if item.Valid {
|
|
var js json.RawMessage
|
|
var syntaxError *json.SyntaxError
|
|
err = json.Unmarshal([]byte(item.String), &js)
|
|
if err == nil || !errors.As(err, &syntaxError) {
|
|
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
|
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
|
}
|
|
|
|
var rows []map[string]any
|
|
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
log.Printf("No records in table %s, no migration needed", tableName)
|
|
return nil
|
|
}
|
|
return fmt.Errorf("find rows: %w", err)
|
|
}
|
|
|
|
for _, row := range rows {
|
|
var blobValue string
|
|
if columnValue := row[orgColumnName]; columnValue != nil {
|
|
value, ok := columnValue.(string)
|
|
if !ok {
|
|
return fmt.Errorf("type assertion failed")
|
|
}
|
|
blobValue = value
|
|
}
|
|
|
|
columnIpValue := net.IP(blobValue)
|
|
if net.ParseIP(columnIpValue.String()) == nil {
|
|
log.WithContext(ctx).Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
|
columnIpValue = net.IPv6loopback
|
|
}
|
|
|
|
jsonValue, err := json.Marshal(columnIpValue)
|
|
if err != nil {
|
|
return fmt.Errorf("re-encode to JSON: %w", err)
|
|
}
|
|
|
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
|
return fmt.Errorf("update row: %w", err)
|
|
}
|
|
}
|
|
|
|
if indexName != "" {
|
|
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
|
|
return fmt.Errorf("drop index %s: %w", indexName, err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
|
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
|
}
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
|
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
|
|
|
|
return nil
|
|
}
|
|
|
|
func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error {
|
|
orgColumnName := "key"
|
|
oldColumnName := GetColumnName(db, orgColumnName)
|
|
newColumnName := "key_secret"
|
|
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
err := stmt.Parse(&model)
|
|
if err != nil {
|
|
return fmt.Errorf("parse model: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
if !tx.Migrator().HasColumn(&model, newColumnName) {
|
|
log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName)
|
|
if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil {
|
|
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
|
}
|
|
}
|
|
|
|
var rows []map[string]any
|
|
if err := tx.Table(tableName).
|
|
Select("id", oldColumnName, newColumnName).
|
|
Where(newColumnName + " IS NULL OR " + newColumnName + " = ''").
|
|
Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'").
|
|
Find(&rows).Error; err != nil {
|
|
return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err)
|
|
}
|
|
|
|
if len(rows) == 0 {
|
|
log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName)
|
|
return nil
|
|
}
|
|
|
|
for _, row := range rows {
|
|
|
|
var plainKey string
|
|
if columnValue := row[orgColumnName]; columnValue != nil {
|
|
value, ok := columnValue.(string)
|
|
if !ok {
|
|
return fmt.Errorf("type assertion failed")
|
|
}
|
|
plainKey = value
|
|
}
|
|
|
|
secretKey := hiddenKey(plainKey, 4)
|
|
|
|
hashedKey := sha256.Sum256([]byte(plainKey))
|
|
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
|
|
|
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil {
|
|
return fmt.Errorf("update row with secret key: %w", err)
|
|
}
|
|
|
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil {
|
|
return fmt.Errorf("update row with hashed key: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s", "peers", "setup_key")).Error; err != nil {
|
|
log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Printf("Migration of plain setup key to hashed setup key completed")
|
|
return nil
|
|
}
|
|
|
|
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
//
// key is the value to mask and length is the number of "*" characters to
// append after the prefix. When length exceeds the number of runes in key it is
// reduced to the rune count minus the prefix length. Keys shorter than five
// bytes are used in full as the prefix, and a negative number of "*" is treated
// as zero, so the function never panics on short or multi-byte input.
func hiddenKey(key string, length int) string {
	// Never slice past the end of a short key.
	prefixLen := 5
	if len(key) < prefixLen {
		prefixLen = len(key)
	}
	prefix := key[:prefixLen]

	if runes := utf8.RuneCountInString(key); length > runes {
		length = runes - len(prefix)
	}
	// strings.Repeat panics on a negative count; the subtraction above mixes a
	// rune count with a byte length and can go below zero for multi-byte keys.
	if length < 0 {
		length = 0
	}
	return prefix + strings.Repeat("*", length)
}
|
|
|
|
func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string, defaultValue any) error {
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
err := stmt.Parse(&model)
|
|
if err != nil {
|
|
return fmt.Errorf("parse model: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
if !tx.Migrator().HasColumn(&model, columnName) {
|
|
log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", columnName, tableName)
|
|
if err := tx.Migrator().AddColumn(&model, columnName); err != nil {
|
|
return fmt.Errorf("add column %s: %w", columnName, err)
|
|
}
|
|
}
|
|
|
|
var rows []map[string]any
|
|
if err := tx.Table(tableName).Select("id", columnName).Where(columnName + " IS NULL").Find(&rows).Error; err != nil {
|
|
return fmt.Errorf("failed to find rows with empty %s: %w", columnName, err)
|
|
}
|
|
|
|
if len(rows) == 0 {
|
|
log.WithContext(ctx).Infof("No rows with empty %s found in table %s, no migration needed", columnName, tableName)
|
|
return nil
|
|
}
|
|
|
|
for _, row := range rows {
|
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, defaultValue).Error; err != nil {
|
|
return fmt.Errorf("failed to update row with id %v: %w", row["id"], err)
|
|
}
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName)
|
|
return nil
|
|
}
|
|
|
|
func DropIndex[T any](ctx context.Context, db *gorm.DB, indexName string) error {
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
if !db.Migrator().HasIndex(&model, indexName) {
|
|
log.WithContext(ctx).Debugf("index %s does not exist in table %T, no migration needed", indexName, model)
|
|
return nil
|
|
}
|
|
|
|
if err := db.Migrator().DropIndex(&model, indexName); err != nil {
|
|
return fmt.Errorf("failed to drop index %s: %w", indexName, err)
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("dropped index %s from table %T", indexName, model)
|
|
return nil
|
|
}
|
|
|
|
func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName string, columns ...string) error {
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
if err := stmt.Parse(&model); err != nil {
|
|
return fmt.Errorf("failed to parse model schema: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
dialect := db.Name()
|
|
|
|
if db.Migrator().HasIndex(&model, indexName) {
|
|
log.WithContext(ctx).Infof("index %s already exists on table %s", indexName, tableName)
|
|
return nil
|
|
}
|
|
|
|
var columnClause string
|
|
if dialect == "mysql" {
|
|
var withLength []string
|
|
for _, col := range columns {
|
|
quotedCol := fmt.Sprintf("`%s`", col)
|
|
if col == "ip" || col == "dns_label" || col == "key" {
|
|
withLength = append(withLength, fmt.Sprintf("%s(64)", quotedCol))
|
|
} else {
|
|
withLength = append(withLength, quotedCol)
|
|
}
|
|
}
|
|
columnClause = strings.Join(withLength, ", ")
|
|
} else {
|
|
columnClause = strings.Join(columns, ", ")
|
|
}
|
|
|
|
createStmt := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, tableName, columnClause)
|
|
if dialect == "postgres" || dialect == "sqlite" {
|
|
createStmt = strings.Replace(createStmt, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("executing index creation: %s", createStmt)
|
|
if err := db.Exec(createStmt).Error; err != nil {
|
|
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
|
return nil
|
|
}
|
|
|
|
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error {
|
|
var model T
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
|
return nil
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: db}
|
|
err := stmt.Parse(&model)
|
|
if err != nil {
|
|
return fmt.Errorf("parse model: %w", err)
|
|
}
|
|
tableName := stmt.Schema.Table
|
|
|
|
if !db.Migrator().HasColumn(&model, columnName) {
|
|
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
|
|
return nil
|
|
}
|
|
|
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
|
var rows []map[string]any
|
|
if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil {
|
|
return fmt.Errorf("find rows: %w", err)
|
|
}
|
|
|
|
for _, row := range rows {
|
|
jsonValue, ok := row[columnName].(string)
|
|
if !ok || jsonValue == "" {
|
|
continue
|
|
}
|
|
|
|
var data []string
|
|
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
|
|
return fmt.Errorf("unmarshal json: %w", err)
|
|
}
|
|
|
|
for _, value := range data {
|
|
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(
|
|
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
|
).Error; err != nil {
|
|
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
|
|
return fmt.Errorf("drop column %s: %w", columnName, err)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
|
return nil
|
|
}
|
|
|
|
// hasForeignKey checks whether a foreign key constraint exists on the given table and column.
|
|
func hasForeignKey(db *gorm.DB, table, column string) bool {
|
|
var count int64
|
|
|
|
switch db.Name() {
|
|
case "postgres":
|
|
db.Raw(`
|
|
SELECT COUNT(*) FROM information_schema.key_column_usage kcu
|
|
JOIN information_schema.table_constraints tc
|
|
ON tc.constraint_name = kcu.constraint_name
|
|
AND tc.table_schema = kcu.table_schema
|
|
WHERE tc.constraint_type = 'FOREIGN KEY'
|
|
AND kcu.table_name = ?
|
|
AND kcu.column_name = ?
|
|
`, table, column).Scan(&count)
|
|
case "mysql":
|
|
db.Raw(`
|
|
SELECT COUNT(*) FROM information_schema.key_column_usage
|
|
WHERE table_schema = DATABASE()
|
|
AND table_name = ?
|
|
AND column_name = ?
|
|
AND referenced_table_name IS NOT NULL
|
|
`, table, column).Scan(&count)
|
|
default: // sqlite
|
|
type fkInfo struct {
|
|
From string
|
|
}
|
|
var fks []fkInfo
|
|
db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks)
|
|
for _, fk := range fks {
|
|
if fk.From == column {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
return count > 0
|
|
}
|
|
|
|
// CleanupOrphanedResources deletes rows from the table of model T where the foreign
|
|
// key column (fkColumn) references a row in the table of model R that no longer exists.
|
|
func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error {
|
|
var model T
|
|
var refModel R
|
|
|
|
if !db.Migrator().HasTable(&model) {
|
|
log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model)
|
|
return nil
|
|
}
|
|
|
|
if !db.Migrator().HasTable(&refModel) {
|
|
log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel)
|
|
return nil
|
|
}
|
|
|
|
stmtT := &gorm.Statement{DB: db}
|
|
if err := stmtT.Parse(&model); err != nil {
|
|
return fmt.Errorf("parse model %T: %w", model, err)
|
|
}
|
|
childTable := stmtT.Schema.Table
|
|
|
|
stmtR := &gorm.Statement{DB: db}
|
|
if err := stmtR.Parse(&refModel); err != nil {
|
|
return fmt.Errorf("parse reference model %T: %w", refModel, err)
|
|
}
|
|
parentTable := stmtR.Schema.Table
|
|
|
|
if !db.Migrator().HasColumn(&model, fkColumn) {
|
|
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable)
|
|
return nil
|
|
}
|
|
|
|
// If a foreign key constraint already exists on the column, the DB itself
|
|
// enforces referential integrity and orphaned rows cannot exist.
|
|
if hasForeignKey(db, childTable, fkColumn) {
|
|
log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable)
|
|
return nil
|
|
}
|
|
|
|
result := db.Exec(
|
|
fmt.Sprintf(
|
|
"DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)",
|
|
childTable, fkColumn, parentTable,
|
|
),
|
|
)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error)
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s",
|
|
result.RowsAffected, childTable, fkColumn, parentTable)
|
|
|
|
return nil
|
|
}
|
|
|
|
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
|
|
if !db.Migrator().HasTable("peers") {
|
|
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")
|
|
return nil
|
|
}
|
|
|
|
keyColumn := GetColumnName(db, "key")
|
|
|
|
var duplicates []struct {
|
|
Key string
|
|
Count int64
|
|
}
|
|
|
|
if err := db.Table("peers").
|
|
Select(keyColumn + ", COUNT(*) as count").
|
|
Group(keyColumn).
|
|
Having("COUNT(*) > 1").
|
|
Find(&duplicates).Error; err != nil {
|
|
return fmt.Errorf("find duplicate keys: %w", err)
|
|
}
|
|
|
|
if len(duplicates) == 0 {
|
|
return nil
|
|
}
|
|
|
|
log.WithContext(ctx).Warnf("Found %d duplicate peer keys, cleaning up", len(duplicates))
|
|
|
|
for _, dup := range duplicates {
|
|
var peerIDs []string
|
|
if err := db.Table("peers").
|
|
Select("id").
|
|
Where(keyColumn+" = ?", dup.Key).
|
|
Order("peer_status_last_seen DESC").
|
|
Pluck("id", &peerIDs).Error; err != nil {
|
|
return fmt.Errorf("get peers for key: %w", err)
|
|
}
|
|
|
|
if len(peerIDs) <= 1 {
|
|
continue
|
|
}
|
|
|
|
idsToDelete := peerIDs[1:]
|
|
|
|
if err := db.Table("peers").Where("id IN ?", idsToDelete).Delete(nil).Error; err != nil {
|
|
return fmt.Errorf("delete duplicate peers: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|