add network migration

This commit is contained in:
Pascal Fischer
2025-07-03 11:20:16 +02:00
parent e23282b92c
commit d06831dd2f
11 changed files with 157 additions and 33 deletions

View File

@@ -1688,7 +1688,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
@@ -1792,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
continue
}
network := types.NewNetwork()
network := types.NewNetwork(accountId)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)

View File

@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
handler := initAccountsTestData(t, &types.Account{
Id: accountID,
Domain: "hotmail.com",
Network: types.NewNetwork(),
Network: types.NewNetwork(accountID),
Users: map[string]*types.User{
adminUser.Id: adminUser,
},

View File

@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"net"
"reflect"
"strings"
"unicode/utf8"
@@ -17,6 +18,16 @@ import (
"gorm.io/gorm"
)
type LegacyAccountNetwork struct {
AccountID string `gorm:"column:id"`
Identifier string `gorm:"column:network_identifier"`
Net net.IPNet `gorm:"column:network_net;serializer:json"`
Dns string `gorm:"column:network_dns"`
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64 `gorm:"column:network_serial"`
}
func GetColumnName(db *gorm.DB, column string) string {
if db.Name() == "mysql" {
return fmt.Sprintf("`%s`", column)
@@ -39,6 +50,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
return nil
}
if !db.Migrator().HasColumn(&model, oldColumnName) {
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
@@ -471,3 +487,90 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into seperte table completed", columnName, tableName)
return nil
}
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
var model T
var u U
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
if db.Migrator().HasTable(&u) {
log.WithContext(ctx).Debugf("table for %T already exists, no migration needed", u)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if err := db.Transaction(func(tx *gorm.DB) error {
var legacyRows []S
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
return fmt.Errorf("failed to read legacy accounts: %w", err)
}
for _, row := range legacyRows {
if err := tx.Create(
mapperFunc(row),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row, err)
}
}
cols, err := getColumnNamesFromStruct(new(S))
if err != nil {
return fmt.Errorf("failed to extract column names: %w", err)
}
for _, col := range cols {
if col == pkey {
continue
}
if err := tx.Migrator().DropColumn(&model, col); err != nil {
return fmt.Errorf("failed to drop column %s: %w", col, err)
}
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
return nil
}
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
val := reflect.TypeOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
var cols []string
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
if field.Name == "ID" {
continue // skip primary key
}
tag := field.Tag.Get("gorm")
if tag == "" {
continue
}
// Look for gorm:"column:..."
for _, part := range strings.Split(tag, ";") {
if strings.HasPrefix(part, "column:") {
cols = append(cols, strings.TrimPrefix(part, "column:"))
break
}
}
}
return cols, nil
}

View File

@@ -588,12 +588,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
// defer func() {
// if unlock != nil {
// unlock()
// }
// }()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
@@ -646,14 +646,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil
})
if err == nil {
unlock()
unlock = nil
// unlock()
// unlock = nil
break
}
if isUniqueConstraintError(err) {
unlock()
unlock = nil
// unlock()
// unlock = nil
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue
}
@@ -1255,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
}()
}()
}

View File

@@ -2145,6 +2145,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
manager, err := createManager(t)
if err != nil {
t.Fatal(err)

View File

@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
@@ -1024,14 +1024,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
accountNetwork := types.Network{}
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
return &accountNetwork, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
@@ -1632,7 +1632,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")

View File

@@ -2059,7 +2059,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[nbroute.ID]*nbroute.Route)

View File

@@ -361,6 +361,17 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
}
})
},
func(db *gorm.DB) error {
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
return &types.Network{
AccountID: obj.AccountID,
Identifier: obj.Identifier,
Net: obj.Net,
Serial: obj.Serial,
Dns: obj.Dns,
}
})
},
}
}

View File

@@ -67,7 +67,7 @@ type Account struct {
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
Peers map[string]*nbpeer.Peer `gorm:"-"`
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`

View File

@@ -107,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
}
type Network struct {
Identifier string `json:"id"`
AccountID string `gorm:"primaryKey"`
Identifier string `gorm:"index"`
Net net.IPNet `gorm:"serializer:json"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
@@ -117,9 +118,13 @@ type Network struct {
Mu sync.Mutex `json:"-" gorm:"-"`
}
func (*Network) TableName() string {
return "account_networks"
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
func NewNetwork(accountID string) *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
@@ -129,6 +134,7 @@ func NewNetwork() *Network {
intn := r.Intn(len(sub))
return &Network{
AccountID: accountID,
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
@@ -151,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
func (n *Network) Copy() *Network {
return &Network{
AccountID: n.AccountID,
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,

View File

@@ -8,7 +8,7 @@ import (
)
func TestNewNetwork(t *testing.T) {
network := NewNetwork()
network := NewNetwork("accountID")
// generated net should be a subnet of a larger 100.64.0.0/10 net
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}