diff --git a/dns/nameserver.go b/dns/nameserver.go index 81c616c50..c05c09715 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -53,6 +53,9 @@ type NameServerGroup struct { ID string `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"` // Name group name Name string // Description group description diff --git a/management/server/account.go b/management/server/account.go index 8e4e595f0..bfb1bad37 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1621,6 +1621,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } + for _, g := range newGroupsToCreate { + seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup) + if err != nil { + return fmt.Errorf("error allocating group seq id: %w", err) + } + g.AccountSeqID = seq + } + if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index ba621030c..31da3f533 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) { user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") + + var newJWTGroup *types.Group + for _, g := range groups { + if g.Name == "group3" { + newJWTGroup = g + break + } + } + require.NotNil(t, newJWTGroup, "JIT-created JWT group not found") + assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID") }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { diff --git a/management/server/group.go b/management/server/group.go index 870a441ac..dd1068ee1 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return err } + seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup) + if err != nil { + return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err) + } + newGroup.AccountSeqID = seq + if err := transaction.CreateGroup(ctx, newGroup); err != nil { return status.Errorf(status.Internal, "failed to create group: %v", err) } @@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + newGroup.AccountSeqID = oldGroup.AccountSeqID + if err = transaction.UpdateGroup(ctx, newGroup); err != nil { return err } @@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us newGroup.AccountID = accountID + seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup) + if err != nil { + return err + } + newGroup.AccountSeqID = seq + if err = transaction.CreateGroup(ctx, newGroup); err != nil { return err } @@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI newGroup.AccountID = accountID + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) + if err != nil { + return err + } + newGroup.AccountSeqID = oldGroup.AccountSeqID + if err := transaction.UpdateGroup(ctx, newGroup); err != nil { return err } diff --git a/management/server/migration/account_seq.go b/management/server/migration/account_seq.go new file mode 100644 index 000000000..ce8095d25 --- /dev/null +++ b/management/server/migration/account_seq.go @@ -0,0 +1,156 @@ +package migration + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + "gorm.io/gorm" + + "github.com/netbirdio/netbird/management/server/types" +) + +// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all +// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters +// with the next free id per account. Idempotent: safe to re-run; both steps +// no-op once everything is consistent. +// +// Implemented as two table-wide SQL statements with window functions, one +// transaction. Backfilling 246k rows across 154k accounts on Postgres takes +// well under a second instead of the per-account-loop ~2 minutes. +// +// orderColumn is the column to use when assigning the deterministic ordering +// (typically the primary-key string id). +func BackfillAccountSeqIDs[T any]( + ctx context.Context, + db *gorm.DB, + entity types.AccountSeqEntity, + orderColumn string, +) error { + var model T + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + if err := stmt.Parse(&model); err != nil { + return fmt.Errorf("parse model: %w", err) + } + table := quoteIdent(db, stmt.Schema.Table) + orderCol := quoteIdent(db, orderColumn) + + return db.Transaction(func(tx *gorm.DB) error { + var pending int64 + if err := tx.Raw( + fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table), + ).Scan(&pending).Error; err != nil { + return fmt.Errorf("count pending on %s: %w", table, err) + } + + if pending > 0 { + log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending) + if err := backfillRankSQL(tx, table, orderCol); err != nil { + return fmt.Errorf("rank %s: %w", table, err) + } + } + + if err := seedCountersSQL(tx, table, entity); err != nil { + return fmt.Errorf("seed counters for %s: %w", entity, err) + } + return nil + }) +} + +func quoteIdent(db *gorm.DB, name string) string { + switch db.Dialector.Name() { + case "mysql": + return "`" + name + "`" + case "postgres": + return `"` + name + `"` + default: + return name + } +} + +func backfillRankSQL(db *gorm.DB, table, orderCol string) error { + dialect := db.Dialector.Name() + var sql string + switch dialect { + case "postgres", "sqlite": + sql = fmt.Sprintf(` +WITH max_seq AS ( + SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq + FROM %s + GROUP BY account_id +), +ranked AS ( + SELECT p.id, + m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq + FROM %s p + JOIN max_seq m ON p.account_id = m.account_id + WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0 +) +UPDATE %s SET account_seq_id = ranked.new_seq +FROM ranked +WHERE %s.id = ranked.id +`, table, orderCol, table, table, table) + case "mysql": + sql = fmt.Sprintf(` +UPDATE %s p +JOIN ( + SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq + FROM %s + GROUP BY account_id +) m ON p.account_id = m.account_id +JOIN ( + SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn + FROM %s + WHERE account_seq_id IS NULL OR account_seq_id = 0 +) r ON p.id = r.id +SET p.account_seq_id = m.max_seq + r.rn +`, table, table, orderCol, table) + default: + return fmt.Errorf("unsupported dialect: %s", dialect) + } + return db.Exec(sql).Error +} + +func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error { + dialect := db.Dialector.Name() + var sql string + switch dialect { + case "postgres": + sql = fmt.Sprintf(` +INSERT INTO account_seq_counters (account_id, entity, next_id) +SELECT account_id, ?, MAX(account_seq_id) + 1 +FROM %s +WHERE account_seq_id IS NOT NULL AND account_seq_id > 0 +GROUP BY account_id +ON CONFLICT (account_id, entity) DO UPDATE + SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id) +`, table) + case "sqlite": + sql = fmt.Sprintf(` +INSERT INTO account_seq_counters (account_id, entity, next_id) +SELECT account_id, ?, MAX(account_seq_id) + 1 +FROM %s +WHERE account_seq_id IS NOT NULL AND account_seq_id > 0 +GROUP BY account_id +ON CONFLICT (account_id, entity) DO UPDATE + SET next_id = max(account_seq_counters.next_id, excluded.next_id) +`, table) + case "mysql": + sql = fmt.Sprintf(` +INSERT INTO account_seq_counters (account_id, entity, next_id) +SELECT account_id, ?, MAX(account_seq_id) + 1 +FROM %s +WHERE account_seq_id IS NOT NULL AND account_seq_id > 0 +GROUP BY account_id +ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id)) +`, table) + default: + return fmt.Errorf("unsupported dialect: %s", dialect) + } + return db.Exec(sql, string(entity)).Error +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5859bfb0d..4277d0996 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } + seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup) + if err != nil { + return err + } + newNSGroup.AccountSeqID = seq + if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { return err } @@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } + nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID + if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { return err } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 5a0e26533..fb098a569 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -125,6 +125,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return fmt.Errorf("failed to get network: %w", err) } + seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource) + if err != nil { + return fmt.Errorf("failed to allocate network resource seq id: %w", err) + } + resource.AccountSeqID = seq + err = transaction.SaveNetworkResource(ctx, resource) if err != nil { return fmt.Errorf("failed to save network resource: %w", err) @@ -231,6 +237,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc if err != nil { return fmt.Errorf("failed to get network resource: %w", err) } + resource.AccountSeqID = oldResource.AccountSeqID err = transaction.SaveNetworkResource(ctx, resource) if err != nil { diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 1fa908393..1ced3ab91 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -32,6 +32,9 @@ type NetworkResource struct { ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"` Name string Description string Type NetworkResourceType diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index c7c3f2ff4..3a985a5b0 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -102,6 +102,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t router.ID = xid.New().String() + seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter) + if err != nil { + return fmt.Errorf("failed to allocate network router seq id: %w", err) + } + router.AccountSeqID = seq + err = transaction.SaveNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to create network router: %w", err) @@ -166,6 +172,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) } + oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID) + if err != nil { + return fmt.Errorf("failed to get existing network router: %w", err) + } + router.AccountSeqID = oldRouter.AccountSeqID + err = transaction.SaveNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to update network router: %w", err) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 6be90baa7..e89fc323c 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { if err != nil { require.NoError(t, err) } + router.ID = "testRouterId" s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 1293a9934..7325599b2 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -13,6 +13,9 @@ type NetworkRouter struct { ID string `gorm:"primaryKey"` NetworkID string `gorm:"index"` AccountID string `gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"` Peer string PeerGroups []string `gorm:"serializer:json"` Masquerade bool diff --git a/management/server/policy.go b/management/server/policy.go index 40f3908e3..e33a33a43 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -69,6 +69,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } + policy.AccountSeqID = existingPolicy.AccountSeqID + if err = transaction.SavePolicy(ctx, policy); err != nil { return err } @@ -78,6 +80,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } + seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy) + if err != nil { + return err + } + policy.AccountSeqID = seq + if err = transaction.CreatePolicy(ctx, policy); err != nil { return err } diff --git a/management/server/route.go b/management/server/route.go index a9561faf0..4d87b68de 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -178,6 +178,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return err } + seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute) + if err != nil { + return err + } + newRoute.AccountSeqID = seq + if err = transaction.SaveRoute(ctx, newRoute); err != nil { return err } @@ -231,6 +237,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } routeToSave.AccountID = accountID + routeToSave.AccountSeqID = oldRoute.AccountSeqID if err = transaction.SaveRoute(ctx, routeToSave); err != nil { return err diff --git a/management/server/store/account_seq_test.go b/management/server/store/account_seq_test.go new file mode 100644 index 000000000..4a8134559 --- /dev/null +++ b/management/server/store/account_seq_test.go @@ -0,0 +1,465 @@ +package store + +import ( + "context" + "errors" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +var errRollback = errors.New("intentional rollback") + +func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accA = "acc-a" + const accB = "acc-b" + + require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error { + got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy) + require.NoError(t, err) + require.Equal(t, uint32(1), got) + + got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy) + require.NoError(t, err) + require.Equal(t, uint32(2), got) + + got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy) + require.NoError(t, err) + require.Equal(t, uint32(1), got, "different account starts from 1") + + got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup) + require.NoError(t, err) + require.Equal(t, uint32(1), got, "different entity starts from 1") + + return nil + })) + + require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error { + got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy) + require.NoError(t, err) + require.Equal(t, uint32(3), got, "counter persists across transactions") + return nil + })) +} + +func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + require.NotEmpty(t, policies, "test fixture must have policies") + + seen := make(map[uint32]bool) + for _, p := range policies { + require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID) + require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID) + seen[p.AccountSeqID] = true + } +} + +func TestPolicyUpdate_PreservesSeqID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + const policyID = "cs1tnh0hhcjnqoiuebf0" + + original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID) + require.NoError(t, err) + originalSeq := original.AccountSeqID + require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill") + + updated := &types.Policy{ + ID: policyID, + AccountID: accountID, + Name: "renamed", + Enabled: false, + Rules: original.Rules, + } + require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would") + + require.NoError(t, store.SavePolicy(ctx, updated)) + + got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID) + require.NoError(t, err) + require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path") + require.Equal(t, "renamed", got.Name) +} + +func TestGroupUpdate_PreservesSeqID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + require.NotEmpty(t, groups) + + original := groups[0] + originalSeq := original.AccountSeqID + require.NotZero(t, originalSeq) + + updated := &types.Group{ + ID: original.ID, + AccountID: accountID, + Name: "renamed", + Issued: original.Issued, + } + require.Zero(t, updated.AccountSeqID) + + require.NoError(t, store.UpdateGroup(ctx, updated)) + + got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID) + require.NoError(t, err) + require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup") + require.Equal(t, "renamed", got.Name) +} + +func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "save-account-seqid-test" + + account := &types.Account{ + Id: accountID, + CreatedBy: "user1", + Domain: "example.test", + DNSSettings: types.DNSSettings{}, + Settings: &types.Settings{}, + Network: &types.Network{ + Identifier: "net-test", + }, + Users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner}, + }, + } + require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy") + require.Len(t, account.Groups, 1, "default 'All' group must be present") + require.Len(t, account.Policies, 1, "default policy must be present") + + for _, g := range account.Groups { + require.Zero(t, g.AccountSeqID, "default group must start with seq=0") + } + require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0") + + require.NoError(t, store.SaveAccount(ctx, account)) + + groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + require.Len(t, groups, 1) + require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount") + + policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + require.Len(t, policies, 1) + require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount") + + require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error { + next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup) + require.NoError(t, err) + require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1") + + next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy) + require.NoError(t, err) + require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1") + return errRollback + }), errRollback) +} + +func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + account, err := store.GetAccount(ctx, accountID) + require.NoError(t, err) + + groupSeqs := make(map[string]uint32) + policySeqs := make(map[string]uint32) + routeSeqs := make(map[route.ID]uint32) + nsgSeqs := make(map[string]uint32) + resourceSeqs := make(map[string]uint32) + routerSeqs := make(map[string]uint32) + + for _, g := range account.Groups { + require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill") + groupSeqs[g.ID] = g.AccountSeqID + } + for _, p := range account.Policies { + require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0") + policySeqs[p.ID] = p.AccountSeqID + } + for _, r := range account.Routes { + require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0") + routeSeqs[r.ID] = r.AccountSeqID + } + for _, n := range account.NameServerGroups { + require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0") + nsgSeqs[n.ID] = n.AccountSeqID + } + for _, nr := range account.NetworkResources { + require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0") + resourceSeqs[nr.ID] = nr.AccountSeqID + } + for _, nr := range account.NetworkRouters { + require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0") + routerSeqs[nr.ID] = nr.AccountSeqID + } + + require.NoError(t, store.SaveAccount(ctx, account)) + + after, err := store.GetAccount(ctx, accountID) + require.NoError(t, err) + for _, g := range after.Groups { + require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID) + } + for _, p := range after.Policies { + require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID) + } + for _, r := range after.Routes { + require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID) + } + for _, n := range after.NameServerGroups { + require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID) + } + for _, nr := range after.NetworkResources { + require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID) + } + for _, nr := range after.NetworkRouters { + require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID) + } +} + +func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "save-account-all-entities" + + addr, err := netip.ParseAddr("8.8.8.8") + require.NoError(t, err) + + account := &types.Account{ + Id: accountID, + CreatedBy: "user1", + Domain: "example.test", + Settings: &types.Settings{}, + Network: &types.Network{Identifier: "net-test"}, + Users: map[string]*types.User{ + "user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner}, + }, + Groups: map[string]*types.Group{ + "g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI}, + }, + Policies: []*types.Policy{ + {ID: "p1", AccountID: accountID, Name: "p1", Enabled: true, + Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}}, + }, + Routes: map[route.ID]*route.Route{ + "rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"}, + }, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true, + NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true}, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true}, + }, + } + + require.NoError(t, store.SaveAccount(ctx, account)) + + after, err := store.GetAccount(ctx, accountID) + require.NoError(t, err) + + require.Len(t, after.Groups, 1) + require.Len(t, after.Policies, 1) + require.Len(t, after.Routes, 1) + require.Len(t, after.NameServerGroups, 1) + require.Len(t, after.NetworkResources, 1) + require.Len(t, after.NetworkRouters, 1) + + for _, g := range after.Groups { + require.NotZero(t, g.AccountSeqID, "group seq must be allocated") + } + for _, p := range after.Policies { + require.NotZero(t, p.AccountSeqID, "policy seq must be allocated") + } + for _, r := range after.Routes { + require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)") + } + for _, n := range after.NameServerGroups { + require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)") + } + for _, nr := range after.NetworkResources { + require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated") + } + for _, nr := range after.NetworkRouters { + require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated") + } + + require.NoError(t, store.SaveAccount(ctx, after)) + final, err := store.GetAccount(ctx, accountID) + require.NoError(t, err) + for _, r := range final.Routes { + require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save") + } + for _, n := range final.NameServerGroups { + require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save") + } +} + +func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "concurrent-test" + const entity = types.AccountSeqEntityPolicy + const goroutines = 32 + + type result struct { + seq uint32 + err error + } + results := make(chan result, goroutines) + start := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + go func() { + <-start + var allocated uint32 + err := store.ExecuteInTransaction(ctx, func(tx Store) error { + seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity) + allocated = seq + return err + }) + results <- result{seq: allocated, err: err} + }() + } + close(start) + + seen := make(map[uint32]int, goroutines) + for i := 0; i < goroutines; i++ { + r := <-results + require.NoError(t, r.err, "concurrent allocate must not fail") + require.NotZero(t, r.seq, "allocated seq must be non-zero") + seen[r.seq]++ + } + + require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen) + for i := uint32(1); i <= goroutines; i++ { + require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i) + } +} + +func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*types.Group{ + {ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777}, + {ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778}, + } + require.NoError(t, store.CreateGroups(ctx, accountID, groups)) + + for _, want := range groups { + got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID) + require.NoError(t, err) + require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert") + } + + groups[0].Name = "g1-renamed" + groups[0].AccountSeqID = 0 + require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1])) + + got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1") + require.NoError(t, err) + require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns") + require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id") +} + +func TestPolicyCreate_AllocatesSeqID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID) + require.NoError(t, err) + maxSeq := uint32(0) + for _, p := range existing { + if p.AccountSeqID > maxSeq { + maxSeq = p.AccountSeqID + } + } + + require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error { + seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy) + if err != nil { + return err + } + require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill") + + newPolicy := &types.Policy{ + ID: "bench-new-policy", + AccountID: accountID, + AccountSeqID: seq, + Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "bench-new-policy-rule", + PolicyID: "bench-new-policy", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + }}, + } + return tx.CreatePolicy(ctx, newPolicy) + })) + + created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy") + require.NoError(t, err) + require.Equal(t, maxSeq+1, created.AccountSeqID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8cf37de56..be0e7f216 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{}, &accesslogs.AccessLogEntry{}, &proxy.Proxy{}, + &types.AccountSeqCounter{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -307,6 +308,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro return result.Error } + if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil { + return fmt.Errorf("assign seq ids: %w", err) + } + result = tx. Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.OnConflict{UpdateAll: true}). @@ -658,6 +663,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error { } // CreateGroups creates the given list of groups to the database. +// groupUpsertColumns is the explicit allowlist of columns that get updated when +// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally +// omitted so a caller passing an entity with the zero value (e.g. an HTTP +// handler-built struct) cannot reset the persisted seq id during an upsert. +// Keep this in sync with the Group schema in management/server/types/group.go. +func groupUpsertColumns() clause.Set { + return clause.AssignmentColumns([]string{ + "account_id", + "name", + "issued", + "integration_ref_id", + "integration_ref_integration_type", + "resources", + }) +} + func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil @@ -667,8 +688,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups [] result := tx. Clauses( clause.OnConflict{ + Columns: []clause.Column{{Name: "id"}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, - UpdateAll: true, + DoUpdates: groupUpsertColumns(), }, ). Omit(clause.Associations). @@ -692,8 +714,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups [] result := tx. Clauses( clause.OnConflict{ + Columns: []clause.Column{{Name: "id"}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, - UpdateAll: true, + DoUpdates: groupUpsertColumns(), }, ). Omit(clause.Associations). @@ -1995,7 +2018,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User } func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { - const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2005,7 +2028,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr var resources []byte var refID sql.NullInt64 var refType sql.NullString - err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType) if err == nil { if refID.Valid { g.IntegrationReference.ID = int(refID.Int64) @@ -2030,7 +2053,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr } func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { - const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2039,7 +2062,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types. var p types.Policy var checks []byte var enabled sql.NullBool - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks) if err == nil { if enabled.Valid { p.Enabled = enabled.Bool @@ -2057,7 +2080,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types. } func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { - const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2067,7 +2090,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou var network, domains, peerGroups, groups, accessGroups []byte var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) if err == nil { if keepRoute.Valid { r.KeepRoute = keepRoute.Bool @@ -2109,7 +2132,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou } func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { - const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2118,7 +2141,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([ var n nbdns.NameServerGroup var ns, groups, domains []byte var primary, enabled, searchDomainsEnabled sql.NullBool - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) if err == nil { if primary.Valid { n.Primary = primary.Bool @@ -2345,7 +2368,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ } func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { - const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2355,7 +2378,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]* var peerGroups []byte var masquerade, enabled sql.NullBool var metric sql.NullInt64 - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) if err == nil { if masquerade.Valid { r.Masquerade = masquerade.Bool @@ -2383,7 +2406,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]* } func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { - const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -2392,7 +2415,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([ var r resourceTypes.NetworkResource var prefix []byte var enabled sql.NullBool - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) if err == nil { if enabled.Valid { r.Enabled = enabled.Bool @@ -3565,6 +3588,145 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// AllocateAccountSeqID returns the next per-account integer id for the given +// component kind. Must be called inside ExecuteInTransaction so the increment +// is serialized with the component insert. +func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) { + return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity) +} + +func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) { + switch engine { + case types.PostgresStoreEngine, types.SqliteStoreEngine: + return allocateAccountSeqIDReturning(db, accountID, entity) + case types.MysqlStoreEngine: + return allocateAccountSeqIDMysql(db, accountID, entity) + default: + return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine) + } +} + +// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT +// DO UPDATE ... RETURNING that gives us the allocated id without a separate +// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity) +// produce two distinct ids: one wins the INSERT, the other wins the UPDATE +// branch and returns next_id+1. +func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) { + const sqlStr = ` + INSERT INTO account_seq_counters (account_id, entity, next_id) + VALUES (?, ?, 2) + ON CONFLICT (account_id, entity) DO UPDATE + SET next_id = account_seq_counters.next_id + 1 + RETURNING (next_id - 1) + ` + var allocated uint32 + if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil { + return 0, fmt.Errorf("upsert account seq counter: %w", err) + } + if allocated == 0 { + return 0, fmt.Errorf("upsert account seq counter returned 0") + } + return allocated, nil +} + +// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning. +// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID +// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value +// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the +// no-conflict path also surfaces the new next_id, keeping the read-back uniform. +// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection, +// so the follow-up SELECT sees the same value. +func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) { + const upsertSQL = ` + INSERT INTO account_seq_counters (account_id, entity, next_id) + VALUES (?, ?, LAST_INSERT_ID(2)) + ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1) + ` + if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil { + return 0, fmt.Errorf("upsert account seq counter: %w", err) + } + var newNext uint64 + if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil { + return 0, fmt.Errorf("get last insert id: %w", err) + } + if newNext == 0 { + return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured") + } + return uint32(newNext - 1), nil +} + +// assignAccountSeqIDs allocates a per-account integer id for any component on +// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so +// the canonical "save the whole account" path produces the same persisted seq +// ids that the manager-level Create paths produce. Update flows that go +// through SaveAccount preserve existing non-zero values. +func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error { + for i := range account.GroupsG { + g := account.GroupsG[i] + if g == nil || g.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup) + if err != nil { + return err + } + g.AccountSeqID = seq + } + for _, p := range account.Policies { + if p == nil || p.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy) + if err != nil { + return err + } + p.AccountSeqID = seq + } + for i := range account.RoutesG { + r := &account.RoutesG[i] + if r.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute) + if err != nil { + return err + } + r.AccountSeqID = seq + } + for i := range account.NameServerGroupsG { + ng := &account.NameServerGroupsG[i] + if ng.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup) + if err != nil { + return err + } + ng.AccountSeqID = seq + } + for _, nr := range account.NetworkResources { + if nr == nil || nr.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource) + if err != nil { + return err + } + nr.AccountSeqID = seq + } + for _, nr := range account.NetworkRouters { + if nr == nil || nr.AccountSeqID != 0 { + continue + } + seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter) + if err != nil { + return err + } + nr.AccountSeqID = seq + } + return nil +} + // transaction wraps a GORM transaction with MySQL-specific FK checks handling // Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora func (s *SqlStore) transaction(fn func(*gorm.DB) error) error { @@ -3754,7 +3916,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error { return status.Errorf(status.InvalidArgument, "group is nil") } - if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil { + if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", err) return status.Errorf(status.Internal, "failed to save group to store") } @@ -3842,7 +4004,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error // SavePolicy saves a policy to the database. func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { - result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy) + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) return status.Errorf(status.Internal, "failed to save policy to store") diff --git a/management/server/store/store.go b/management/server/store/store.go index a723c1fc3..28aa2e264 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -220,6 +220,11 @@ type Store interface { GetStoreEngine() types.Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + // AllocateAccountSeqID returns the next per-account integer id for the given + // component kind. Must run inside a transaction so the increment is serialized + // with the component insert. + AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) SaveNetwork(ctx context.Context, network *networkTypes.Network) error @@ -522,6 +527,24 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique") }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id") + }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id") + }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id") + }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id") + }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id") + }, + func(db *gorm.DB) error { + return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id") + }, } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index d51629606..25c743865 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -746,6 +746,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain) } +// AllocateAccountSeqID mocks base method. +func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID. +func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity) +} + // ExecuteInTransaction mocks base method. func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error { m.ctrl.T.Helper() diff --git a/management/server/types/account_seq_counter.go b/management/server/types/account_seq_counter.go new file mode 100644 index 000000000..e21f73317 --- /dev/null +++ b/management/server/types/account_seq_counter.go @@ -0,0 +1,27 @@ +package types + +// AccountSeqEntity identifies the kind of component that uses a per-account sequence. +type AccountSeqEntity string + +const ( + AccountSeqEntityPolicy AccountSeqEntity = "policy" + AccountSeqEntityGroup AccountSeqEntity = "group" + AccountSeqEntityRoute AccountSeqEntity = "route" + AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource" + AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router" + AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group" +) + +// AccountSeqCounter tracks the next per-account integer id for a given component +// kind. Reads/writes go through the store inside the same transaction as the +// component insert so two concurrent inserts cannot collide on the same id. +type AccountSeqCounter struct { + AccountID string `gorm:"primaryKey;size:255"` + Entity string `gorm:"primaryKey;size:32"` + NextID uint32 `gorm:"not null;default:1"` +} + +// TableName overrides the GORM-derived table name. +func (AccountSeqCounter) TableName() string { + return "account_seq_counters" +} diff --git a/management/server/types/group.go b/management/server/types/group.go index b4f50080a..f47bca60c 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -19,6 +19,10 @@ type Group struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"` + // Name visible in the UI Name string diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d410aec8d..e3f94e178 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -59,6 +59,10 @@ type Policy struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"` + // Name of the Policy Name string diff --git a/route/route.go b/route/route.go index 97b9721f6..4a8c342b2 100644 --- a/route/route.go +++ b/route/route.go @@ -95,6 +95,9 @@ type Route struct { ID ID `gorm:"primaryKey"` // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` + // AccountSeqID is a per-account monotonically increasing identifier used as the + // compact wire id when sending NetworkMap components to capable peers. + AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"` // Network and Domains are mutually exclusive Network netip.Prefix `gorm:"serializer:json"` Domains domain.List `gorm:"serializer:json"`