mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
add network migration
This commit is contained in:
@@ -1688,7 +1688,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
@@ -1792,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
continue
|
||||
}
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountId)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
handler := initAccountsTestData(t, &types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Network: types.NewNetwork(),
|
||||
Network: types.NewNetwork(accountID),
|
||||
Users: map[string]*types.User{
|
||||
adminUser.Id: adminUser,
|
||||
},
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -17,6 +18,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LegacyAccountNetwork struct {
|
||||
AccountID string `gorm:"column:id"`
|
||||
Identifier string `gorm:"column:network_identifier"`
|
||||
Net net.IPNet `gorm:"column:network_net;serializer:json"`
|
||||
Dns string `gorm:"column:network_dns"`
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64 `gorm:"column:network_serial"`
|
||||
}
|
||||
|
||||
func GetColumnName(db *gorm.DB, column string) string {
|
||||
if db.Name() == "mysql" {
|
||||
return fmt.Sprintf("`%s`", column)
|
||||
@@ -39,6 +50,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasColumn(&model, oldColumnName) {
|
||||
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
@@ -471,3 +487,90 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into seperte table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
|
||||
var model T
|
||||
var u U
|
||||
|
||||
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
if db.Migrator().HasTable(&u) {
|
||||
log.WithContext(ctx).Debugf("table for %T already exists, no migration needed", u)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var legacyRows []S
|
||||
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
|
||||
return fmt.Errorf("failed to read legacy accounts: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range legacyRows {
|
||||
if err := tx.Create(
|
||||
mapperFunc(row),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row, err)
|
||||
}
|
||||
}
|
||||
|
||||
cols, err := getColumnNamesFromStruct(new(S))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract column names: %w", err)
|
||||
}
|
||||
|
||||
for _, col := range cols {
|
||||
if col == pkey {
|
||||
continue
|
||||
}
|
||||
if err := tx.Migrator().DropColumn(&model, col); err != nil {
|
||||
return fmt.Errorf("failed to drop column %s: %w", col, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
|
||||
val := reflect.TypeOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
var cols []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
if field.Name == "ID" {
|
||||
continue // skip primary key
|
||||
}
|
||||
tag := field.Tag.Get("gorm")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
// Look for gorm:"column:..."
|
||||
for _, part := range strings.Split(tag, ";") {
|
||||
if strings.HasPrefix(part, "column:") {
|
||||
cols = append(cols, strings.TrimPrefix(part, "column:"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
@@ -588,12 +588,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
// defer func() {
|
||||
// if unlock != nil {
|
||||
// unlock()
|
||||
// }
|
||||
// }()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||
@@ -646,14 +646,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
unlock()
|
||||
unlock = nil
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
break
|
||||
}
|
||||
|
||||
if isUniqueConstraintError(err) {
|
||||
unlock()
|
||||
unlock = nil
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||
continue
|
||||
}
|
||||
@@ -1255,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -2145,6 +2145,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
|
||||
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
@@ -1024,14 +1024,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
accountNetwork := types.Network{}
|
||||
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
return &accountNetwork, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
@@ -1632,7 +1632,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
|
||||
@@ -2059,7 +2059,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[nbroute.ID]*nbroute.Route)
|
||||
|
||||
@@ -361,6 +361,17 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
}
|
||||
})
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
|
||||
return &types.Network{
|
||||
AccountID: obj.AccountID,
|
||||
Identifier: obj.Identifier,
|
||||
Net: obj.Net,
|
||||
Serial: obj.Serial,
|
||||
Dns: obj.Dns,
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ type Account struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
SetupKeys map[string]*SetupKey `gorm:"-"`
|
||||
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
||||
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Users map[string]*User `gorm:"-"`
|
||||
|
||||
@@ -107,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Identifier string `json:"id"`
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
Identifier string `gorm:"index"`
|
||||
Net net.IPNet `gorm:"serializer:json"`
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
@@ -117,9 +118,13 @@ type Network struct {
|
||||
Mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
func (*Network) TableName() string {
|
||||
return "account_networks"
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
|
||||
func NewNetwork() *Network {
|
||||
func NewNetwork(accountID string) *Network {
|
||||
|
||||
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
|
||||
sub, _ := n.Subnet(SubnetSize)
|
||||
@@ -129,6 +134,7 @@ func NewNetwork() *Network {
|
||||
intn := r.Intn(len(sub))
|
||||
|
||||
return &Network{
|
||||
AccountID: accountID,
|
||||
Identifier: xid.New().String(),
|
||||
Net: sub[intn].IPNet,
|
||||
Dns: "",
|
||||
@@ -151,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
|
||||
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
AccountID: n.AccountID,
|
||||
Identifier: n.Identifier,
|
||||
Net: n.Net,
|
||||
Dns: n.Dns,
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNewNetwork(t *testing.T) {
|
||||
network := NewNetwork()
|
||||
network := NewNetwork("accountID")
|
||||
|
||||
// generated net should be a subnet of a larger 100.64.0.0/10 net
|
||||
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}
|
||||
|
||||
Reference in New Issue
Block a user