diff --git a/management/server/account.go b/management/server/account.go index fdaa6ddef..6d5705cce 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1688,7 +1688,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := types.NewNetwork() + network := types.NewNetwork(accountID) peers := make(map[string]*nbpeer.Peer) users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) @@ -1792,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C continue } - network := types.NewNetwork() + network := types.NewNetwork(accountId) peers := make(map[string]*nbpeer.Peer) users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index a18798743..9f64758c8 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { handler := initAccountsTestData(t, &types.Account{ Id: accountID, Domain: "hotmail.com", - Network: types.NewNetwork(), + Network: types.NewNetwork(accountID), Users: map[string]*types.User{ adminUser.Id: adminUser, }, diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index cd81dbc7d..458cb7d60 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net" + "reflect" "strings" "unicode/utf8" @@ -17,6 +18,16 @@ import ( "gorm.io/gorm" ) +type LegacyAccountNetwork struct { + AccountID string `gorm:"column:id"` + Identifier string `gorm:"column:network_identifier"` + Net net.IPNet `gorm:"column:network_net;serializer:json"` + Dns string `gorm:"column:network_dns"` + // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). + // Used to synchronize state to the client apps. + Serial uint64 `gorm:"column:network_serial"` +} + func GetColumnName(db *gorm.DB, column string) string { if db.Name() == "mysql" { return fmt.Sprintf("`%s`", column) @@ -39,6 +50,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f return nil } + if !db.Migrator().HasColumn(&model, oldColumnName) { + log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName) + return nil + } + stmt := &gorm.Statement{DB: db} err := stmt.Parse(model) if err != nil { @@ -471,3 +487,90 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into seperte table completed", columnName, tableName) return nil } + +func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error { + var model T + var u U + + log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model) + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + if db.Migrator().HasTable(&u) { + log.WithContext(ctx).Debugf("table for %T already exists, no migration needed", u) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + var legacyRows []S + if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil { + log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err) + return fmt.Errorf("failed to read legacy accounts: %w", err) + } + + for _, row := range legacyRows { + if err := tx.Create( + mapperFunc(row), + ).Error; err != nil { + return fmt.Errorf("failed to insert id %v: %w", row, err) + } + } + + cols, err := getColumnNamesFromStruct(new(S)) + if err != nil { + return fmt.Errorf("failed to extract column names: %w", err) + } + + for _, col := range cols { + if col == pkey { + continue + } + if err := tx.Migrator().DropColumn(&model, col); err != nil { + return fmt.Errorf("failed to drop column %s: %w", col, err) + } + } + + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName) + return nil +} + +func getColumnNamesFromStruct[T any](model T) ([]string, error) { + val := reflect.TypeOf(model) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + var cols []string + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + if field.Name == "ID" { + continue // skip primary key + } + tag := field.Tag.Get("gorm") + if tag == "" { + continue + } + // Look for gorm:"column:..." + for _, part := range strings.Split(tag, ";") { + if strings.HasPrefix(part, "column:") { + cols = append(cols, strings.TrimPrefix(part, "column:")) + break + } + } + } + return cols, nil +} diff --git a/management/server/peer.go b/management/server/peer.go index 77d4b024c..2dcaa8009 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -588,12 +588,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s newPeer.DNSLabel = freeLabel newPeer.IP = freeIP - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer func() { - if unlock != nil { - unlock() - } - }() + // unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + // defer func() { + // if unlock != nil { + // unlock() + // } + // }() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) @@ -646,14 +646,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil }) if err == nil { - unlock() - unlock = nil + // unlock() + // unlock = nil break } if isUniqueConstraintError(err) { - unlock() - unlock = nil + // unlock() + // unlock = nil log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) continue } @@ -1255,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) - lock := mu.(*sync.Mutex) - - if !lock.TryLock() { - return - } - go func() { - time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) - lock.Unlock() - am.UpdateAccountPeers(ctx, accountID) + mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) + lock := mu.(*sync.Mutex) + + if !lock.TryLock() { + return + } + + go func() { + time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) + lock.Unlock() + am.UpdateAccountPeers(ctx, accountID) + }() }() } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 67cc6abeb..9a3f36547 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2145,6 +2145,7 @@ func Test_IsUniqueConstraintError(t *testing.T) { func Test_AddPeer(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms") + t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms") manager, err := createManager(t) if err != nil { t.Fatal(err) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2b1073c33..3946706fa 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, @@ -1024,14 +1024,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var accountNetwork types.AccountNetwork - if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + accountNetwork := types.Network{} + if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } - return accountNetwork.Network, nil + return &accountNetwork, nil } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { @@ -1632,7 +1632,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 9b5101c79..0d3b30f66 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2059,7 +2059,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := types.NewNetwork() + network := types.NewNetwork(accountID) peers := make(map[string]*nbpeer.Peer) users := make(map[string]*types.User) routes := make(map[nbroute.ID]*nbroute.Route) diff --git a/management/server/store/store.go b/management/server/store/store.go index fad12ce77..b80fe987c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -361,6 +361,17 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { } }) }, + func(db *gorm.DB) error { + return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network { + return &types.Network{ + AccountID: obj.AccountID, + Identifier: obj.Identifier, + Net: obj.Net, + Serial: obj.Serial, + Dns: obj.Dns, + } + }) + }, } } diff --git a/management/server/types/account.go b/management/server/types/account.go index 5a62ee4c6..13b9b5107 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -67,7 +67,7 @@ type Account struct { IsDomainPrimaryAccount bool SetupKeys map[string]*SetupKey `gorm:"-"` SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` - Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"` Peers map[string]*nbpeer.Peer `gorm:"-"` PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` diff --git a/management/server/types/network.go b/management/server/types/network.go index eb8415264..317b1a8cc 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -107,7 +107,8 @@ func ipToBytes(ip net.IP) []byte { } type Network struct { - Identifier string `json:"id"` + AccountID string `gorm:"primaryKey"` + Identifier string `gorm:"index"` Net net.IPNet `gorm:"serializer:json"` Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). @@ -117,9 +118,13 @@ type Network struct { Mu sync.Mutex `json:"-" gorm:"-"` } +func (*Network) TableName() string { + return "account_networks" +} + // NewNetwork creates a new Network initializing it with a Serial=0 // It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets) -func NewNetwork() *Network { +func NewNetwork(accountID string) *Network { n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize) sub, _ := n.Subnet(SubnetSize) @@ -129,6 +134,7 @@ func NewNetwork() *Network { intn := r.Intn(len(sub)) return &Network{ + AccountID: accountID, Identifier: xid.New().String(), Net: sub[intn].IPNet, Dns: "", @@ -151,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 { func (n *Network) Copy() *Network { return &Network{ + AccountID: n.AccountID, Identifier: n.Identifier, Net: n.Net, Dns: n.Dns, diff --git a/management/server/types/network_test.go b/management/server/types/network_test.go index d0b0894d4..ac6429cf5 100644 --- a/management/server/types/network_test.go +++ b/management/server/types/network_test.go @@ -8,7 +8,7 @@ import ( ) func TestNewNetwork(t *testing.T) { - network := NewNetwork() + network := NewNetwork("accountID") // generated net should be a subnet of a larger 100.64.0.0/10 net ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}