mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 13:09:55 +00:00
init int inds migration
This commit is contained in:
@@ -53,6 +53,9 @@ type NameServerGroup struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
|
||||
// Name group name
|
||||
Name string
|
||||
// Description group description
|
||||
|
||||
@@ -1621,6 +1621,14 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range newGroupsToCreate {
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error allocating group seq id: %w", err)
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
|
||||
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -3036,6 +3036,16 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
|
||||
var newJWTGroup *types.Group
|
||||
for _, g := range groups {
|
||||
if g.Name == "group3" {
|
||||
newJWTGroup = g
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
|
||||
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
|
||||
@@ -96,6 +96,12 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to allocate group seq id: %v", err)
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -170,6 +176,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -221,6 +229,12 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -320,6 +334,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newGroup.AccountSeqID = oldGroup.AccountSeqID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
156
management/server/migration/account_seq.go
Normal file
156
management/server/migration/account_seq.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// BackfillAccountSeqIDs assigns a deterministic per-account sequential id to all
|
||||
// rows of `model` whose account_seq_id is zero, then seeds account_seq_counters
|
||||
// with the next free id per account. Idempotent: safe to re-run; both steps
|
||||
// no-op once everything is consistent.
|
||||
//
|
||||
// Implemented as two table-wide SQL statements with window functions, one
|
||||
// transaction. Backfilling 246k rows across 154k accounts on Postgres takes
|
||||
// well under a second instead of the per-account-loop ~2 minutes.
|
||||
//
|
||||
// orderColumn is the column to use when assigning the deterministic ordering
|
||||
// (typically the primary-key string id).
|
||||
func BackfillAccountSeqIDs[T any](
|
||||
ctx context.Context,
|
||||
db *gorm.DB,
|
||||
entity types.AccountSeqEntity,
|
||||
orderColumn string,
|
||||
) error {
|
||||
var model T
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("backfill seq id: table for %T missing, skip", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
if err := stmt.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
table := quoteIdent(db, stmt.Schema.Table)
|
||||
orderCol := quoteIdent(db, orderColumn)
|
||||
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var pending int64
|
||||
if err := tx.Raw(
|
||||
fmt.Sprintf("SELECT count(*) FROM %s WHERE account_seq_id IS NULL OR account_seq_id = 0", table),
|
||||
).Scan(&pending).Error; err != nil {
|
||||
return fmt.Errorf("count pending on %s: %w", table, err)
|
||||
}
|
||||
|
||||
if pending > 0 {
|
||||
log.WithContext(ctx).Infof("backfill seq id: %s — %d rows pending", table, pending)
|
||||
if err := backfillRankSQL(tx, table, orderCol); err != nil {
|
||||
return fmt.Errorf("rank %s: %w", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := seedCountersSQL(tx, table, entity); err != nil {
|
||||
return fmt.Errorf("seed counters for %s: %w", entity, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func quoteIdent(db *gorm.DB, name string) string {
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql":
|
||||
return "`" + name + "`"
|
||||
case "postgres":
|
||||
return `"` + name + `"`
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
func backfillRankSQL(db *gorm.DB, table, orderCol string) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres", "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
WITH max_seq AS (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
),
|
||||
ranked AS (
|
||||
SELECT p.id,
|
||||
m.max_seq + ROW_NUMBER() OVER (PARTITION BY p.account_id ORDER BY p.%s) AS new_seq
|
||||
FROM %s p
|
||||
JOIN max_seq m ON p.account_id = m.account_id
|
||||
WHERE p.account_seq_id IS NULL OR p.account_seq_id = 0
|
||||
)
|
||||
UPDATE %s SET account_seq_id = ranked.new_seq
|
||||
FROM ranked
|
||||
WHERE %s.id = ranked.id
|
||||
`, table, orderCol, table, table, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
UPDATE %s p
|
||||
JOIN (
|
||||
SELECT account_id, COALESCE(MAX(account_seq_id), 0) AS max_seq
|
||||
FROM %s
|
||||
GROUP BY account_id
|
||||
) m ON p.account_id = m.account_id
|
||||
JOIN (
|
||||
SELECT id, ROW_NUMBER() OVER (PARTITION BY account_id ORDER BY %s) AS rn
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NULL OR account_seq_id = 0
|
||||
) r ON p.id = r.id
|
||||
SET p.account_seq_id = m.max_seq + r.rn
|
||||
`, table, table, orderCol, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql).Error
|
||||
}
|
||||
|
||||
func seedCountersSQL(db *gorm.DB, table string, entity types.AccountSeqEntity) error {
|
||||
dialect := db.Dialector.Name()
|
||||
var sql string
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = GREATEST(account_seq_counters.next_id, EXCLUDED.next_id)
|
||||
`, table)
|
||||
case "sqlite":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = max(account_seq_counters.next_id, excluded.next_id)
|
||||
`, table)
|
||||
case "mysql":
|
||||
sql = fmt.Sprintf(`
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
SELECT account_id, ?, MAX(account_seq_id) + 1
|
||||
FROM %s
|
||||
WHERE account_seq_id IS NOT NULL AND account_seq_id > 0
|
||||
GROUP BY account_id
|
||||
ON DUPLICATE KEY UPDATE next_id = GREATEST(next_id, VALUES(next_id))
|
||||
`, table)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dialect: %s", dialect)
|
||||
}
|
||||
return db.Exec(sql, string(entity)).Error
|
||||
}
|
||||
@@ -69,6 +69,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNSGroup.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -120,6 +126,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroupToSave.AccountSeqID = oldNSGroup.AccountSeqID
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -125,6 +125,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, resource.AccountID, nbtypes.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network resource seq id: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -231,6 +237,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
resource.AccountSeqID = oldResource.AccountSeqID
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,9 @@ type NetworkResource struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_resources_account_seq_id;not null;default:0"`
|
||||
Name string
|
||||
Description string
|
||||
Type NetworkResourceType
|
||||
|
||||
@@ -102,6 +102,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
router.ID = xid.New().String()
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, router.AccountID, serverTypes.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to allocate network router seq id: %w", err)
|
||||
}
|
||||
router.AccountSeqID = seq
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create network router: %w", err)
|
||||
@@ -166,6 +172,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing network router: %w", err)
|
||||
}
|
||||
router.AccountSeqID = oldRouter.AccountSeqID
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
|
||||
@@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
router.ID = "testRouterId"
|
||||
|
||||
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,9 @@ type NetworkRouter struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
NetworkID string `gorm:"index"`
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_network_routers_account_seq_id;not null;default:0"`
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
Masquerade bool
|
||||
|
||||
@@ -69,6 +69,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
policy.AccountSeqID = existingPolicy.AccountSeqID
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,6 +80,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
policy.AccountSeqID = seq
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -178,6 +178,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
seq, err := transaction.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRoute.AccountSeqID = seq
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -231,6 +237,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
routeToSave.AccountSeqID = oldRoute.AccountSeqID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
|
||||
465
management/server/store/account_seq_test.go
Normal file
465
management/server/store/account_seq_test.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
var errRollback = errors.New("intentional rollback")
|
||||
|
||||
func TestAllocateAccountSeqID_SequentialPerAccount(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accA = "acc-a"
|
||||
const accB = "acc-b"
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), got)
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accB, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different account starts from 1")
|
||||
|
||||
got, err = tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), got, "different entity starts from 1")
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
got, err := tx.AllocateAccountSeqID(ctx, accA, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), got, "counter persists across transactions")
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestPolicyBackfill_AssignsSeqIDsToExistingPolicies(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, policies, "test fixture must have policies")
|
||||
|
||||
seen := make(map[uint32]bool)
|
||||
for _, p := range policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy %s must have a non-zero AccountSeqID after migration", p.ID)
|
||||
require.False(t, seen[p.AccountSeqID], "duplicate AccountSeqID %d in account %s", p.AccountSeqID, accountID)
|
||||
seen[p.AccountSeqID] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
const policyID = "cs1tnh0hhcjnqoiuebf0"
|
||||
|
||||
original, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq, "fixture must have non-zero AccountSeqID after backfill")
|
||||
|
||||
updated := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Enabled: false,
|
||||
Rules: original.Rules,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID, "incoming struct should have zero AccountSeqID like an HTTP handler would")
|
||||
|
||||
require.NoError(t, store.SavePolicy(ctx, updated))
|
||||
|
||||
got, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, policyID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by update path")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestGroupUpdate_PreservesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups)
|
||||
|
||||
original := groups[0]
|
||||
originalSeq := original.AccountSeqID
|
||||
require.NotZero(t, originalSeq)
|
||||
|
||||
updated := &types.Group{
|
||||
ID: original.ID,
|
||||
AccountID: accountID,
|
||||
Name: "renamed",
|
||||
Issued: original.Issued,
|
||||
}
|
||||
require.Zero(t, updated.AccountSeqID)
|
||||
|
||||
require.NoError(t, store.UpdateGroup(ctx, updated))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, original.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, originalSeq, got.AccountSeqID, "AccountSeqID must not be reset by UpdateGroup")
|
||||
require.Equal(t, "renamed", got.Name)
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForDefaultGroupAndPolicy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-seqid-test"
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
DNSSettings: types.DNSSettings{},
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test",
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
}
|
||||
require.NoError(t, account.AddAllGroup(false), "AddAllGroup should populate default Group + Policy")
|
||||
require.Len(t, account.Groups, 1, "default 'All' group must be present")
|
||||
require.Len(t, account.Policies, 1, "default policy must be present")
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.Zero(t, g.AccountSeqID, "default group must start with seq=0")
|
||||
}
|
||||
require.Zero(t, account.Policies[0].AccountSeqID, "default policy must start with seq=0")
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
groups, err := store.GetAccountGroups(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.NotZerof(t, groups[0].AccountSeqID, "default group must have seq>0 after SaveAccount")
|
||||
|
||||
policies, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
require.NotZerof(t, policies[0].AccountSeqID, "default policy must have seq>0 after SaveAccount")
|
||||
|
||||
require.ErrorIs(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
next, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityGroup)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, groups[0].AccountSeqID+1, next, "next group seq must be max+1")
|
||||
|
||||
next, err = tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, policies[0].AccountSeqID+1, next, "next policy seq must be max+1")
|
||||
return errRollback
|
||||
}), errRollback)
|
||||
}
|
||||
|
||||
func TestSaveAccount_PreservesExistingSeqIDs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
account, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupSeqs := make(map[string]uint32)
|
||||
policySeqs := make(map[string]uint32)
|
||||
routeSeqs := make(map[route.ID]uint32)
|
||||
nsgSeqs := make(map[string]uint32)
|
||||
resourceSeqs := make(map[string]uint32)
|
||||
routerSeqs := make(map[string]uint32)
|
||||
|
||||
for _, g := range account.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "fixture group must have seq>0 after backfill")
|
||||
groupSeqs[g.ID] = g.AccountSeqID
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "fixture policy must have seq>0")
|
||||
policySeqs[p.ID] = p.AccountSeqID
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "fixture route must have seq>0")
|
||||
routeSeqs[r.ID] = r.AccountSeqID
|
||||
}
|
||||
for _, n := range account.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "fixture name_server_group must have seq>0")
|
||||
nsgSeqs[n.ID] = n.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_resource must have seq>0")
|
||||
resourceSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "fixture network_router must have seq>0")
|
||||
routerSeqs[nr.ID] = nr.AccountSeqID
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, g := range after.Groups {
|
||||
require.Equal(t, groupSeqs[g.ID], g.AccountSeqID, "group %s seq must be preserved on re-save", g.ID)
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.Equal(t, policySeqs[p.ID], p.AccountSeqID, "policy %s seq must be preserved", p.ID)
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.Equal(t, routeSeqs[r.ID], r.AccountSeqID, "route %s seq must be preserved (slice-of-value addressability)", r.ID)
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.Equal(t, nsgSeqs[n.ID], n.AccountSeqID, "name_server_group %s seq must be preserved (slice-of-value addressability)", n.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.Equal(t, resourceSeqs[nr.ID], nr.AccountSeqID, "network_resource %s seq must be preserved", nr.ID)
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.Equal(t, routerSeqs[nr.ID], nr.AccountSeqID, "network_router %s seq must be preserved", nr.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAccount_AllocatesSeqIDsForAllEntityTypes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "save-account-all-entities"
|
||||
|
||||
addr, err := netip.ParseAddr("8.8.8.8")
|
||||
require.NoError(t, err)
|
||||
|
||||
account := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedBy: "user1",
|
||||
Domain: "example.test",
|
||||
Settings: &types.Settings{},
|
||||
Network: &types.Network{Identifier: "net-test"},
|
||||
Users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: accountID, Role: types.UserRoleOwner},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"g1": {ID: "g1", AccountID: accountID, Name: "g1", Issued: types.GroupIssuedAPI},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
{ID: "p1", AccountID: accountID, Name: "p1", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{ID: "r1", PolicyID: "p1", Enabled: true}}},
|
||||
},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"rt1": {ID: "rt1", AccountID: accountID, NetID: "net1", Peer: "peer1"},
|
||||
},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"nsg1": {ID: "nsg1", AccountID: accountID, Name: "nsg1", Enabled: true,
|
||||
NameServers: []nbdns.NameServer{{IP: addr, NSType: nbdns.UDPNameServerType, Port: 53}}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: "nr1", AccountID: accountID, NetworkID: "net1", Name: "res1", Enabled: true},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: "nrt1", AccountID: accountID, NetworkID: "net1", Peer: "peer1", Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, account))
|
||||
|
||||
after, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, after.Groups, 1)
|
||||
require.Len(t, after.Policies, 1)
|
||||
require.Len(t, after.Routes, 1)
|
||||
require.Len(t, after.NameServerGroups, 1)
|
||||
require.Len(t, after.NetworkResources, 1)
|
||||
require.Len(t, after.NetworkRouters, 1)
|
||||
|
||||
for _, g := range after.Groups {
|
||||
require.NotZero(t, g.AccountSeqID, "group seq must be allocated")
|
||||
}
|
||||
for _, p := range after.Policies {
|
||||
require.NotZero(t, p.AccountSeqID, "policy seq must be allocated")
|
||||
}
|
||||
for _, r := range after.Routes {
|
||||
require.NotZero(t, r.AccountSeqID, "route seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, n := range after.NameServerGroups {
|
||||
require.NotZero(t, n.AccountSeqID, "name_server_group seq must be allocated (slice-of-value addressability)")
|
||||
}
|
||||
for _, nr := range after.NetworkResources {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_resource seq must be allocated")
|
||||
}
|
||||
for _, nr := range after.NetworkRouters {
|
||||
require.NotZero(t, nr.AccountSeqID, "network_router seq must be allocated")
|
||||
}
|
||||
|
||||
require.NoError(t, store.SaveAccount(ctx, after))
|
||||
final, err := store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, r := range final.Routes {
|
||||
require.Equal(t, after.Routes[r.ID].AccountSeqID, r.AccountSeqID, "route seq preserved on re-save")
|
||||
}
|
||||
for _, n := range final.NameServerGroups {
|
||||
require.Equal(t, after.NameServerGroups[n.ID].AccountSeqID, n.AccountSeqID, "name_server_group seq preserved on re-save")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateAccountSeqID_ConcurrentSameAccountEntity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "concurrent-test"
|
||||
const entity = types.AccountSeqEntityPolicy
|
||||
const goroutines = 32
|
||||
|
||||
type result struct {
|
||||
seq uint32
|
||||
err error
|
||||
}
|
||||
results := make(chan result, goroutines)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
<-start
|
||||
var allocated uint32
|
||||
err := store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, entity)
|
||||
allocated = seq
|
||||
return err
|
||||
})
|
||||
results <- result{seq: allocated, err: err}
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
|
||||
seen := make(map[uint32]int, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
r := <-results
|
||||
require.NoError(t, r.err, "concurrent allocate must not fail")
|
||||
require.NotZero(t, r.seq, "allocated seq must be non-zero")
|
||||
seen[r.seq]++
|
||||
}
|
||||
|
||||
require.Lenf(t, seen, goroutines, "every concurrent allocation must yield a unique id; got duplicates in %v", seen)
|
||||
for i := uint32(1); i <= goroutines; i++ {
|
||||
require.Equalf(t, 1, seen[i], "id %d must appear exactly once across concurrent allocations", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCreateGroups_AllocatedSeqIDIsNotClobbered(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
groups := []*types.Group{
|
||||
{ID: "seq-test-g1", AccountID: accountID, Name: "g1", Issued: "jwt", AccountSeqID: 7777},
|
||||
{ID: "seq-test-g2", AccountID: accountID, Name: "g2", Issued: "jwt", AccountSeqID: 7778},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
for _, want := range groups {
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, want.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want.AccountSeqID, got.AccountSeqID, "seq id from caller must be persisted on insert")
|
||||
}
|
||||
|
||||
groups[0].Name = "g1-renamed"
|
||||
groups[0].AccountSeqID = 0
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups[:1]))
|
||||
|
||||
got, err := store.GetGroupByID(ctx, LockingStrengthNone, accountID, "seq-test-g1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "g1-renamed", got.Name, "upsert path still updates other columns")
|
||||
require.Equal(t, uint32(7777), got.AccountSeqID, "upsert path must NOT overwrite account_seq_id")
|
||||
}
|
||||
|
||||
func TestPolicyCreate_AllocatesSeqID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
existing, err := store.GetAccountPolicies(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
maxSeq := uint32(0)
|
||||
for _, p := range existing {
|
||||
if p.AccountSeqID > maxSeq {
|
||||
maxSeq = p.AccountSeqID
|
||||
}
|
||||
}
|
||||
|
||||
require.NoError(t, store.ExecuteInTransaction(ctx, func(tx Store) error {
|
||||
seq, err := tx.AllocateAccountSeqID(ctx, accountID, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
require.Equal(t, maxSeq+1, seq, "next id should be max+1 after backfill")
|
||||
|
||||
newPolicy := &types.Policy{
|
||||
ID: "bench-new-policy",
|
||||
AccountID: accountID,
|
||||
AccountSeqID: seq,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "bench-new-policy-rule",
|
||||
PolicyID: "bench-new-policy",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
}},
|
||||
}
|
||||
return tx.CreatePolicy(ctx, newPolicy)
|
||||
}))
|
||||
|
||||
created, err := store.GetPolicyByID(ctx, LockingStrengthNone, accountID, "bench-new-policy")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, maxSeq+1, created.AccountSeqID)
|
||||
}
|
||||
@@ -137,6 +137,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
&types.AccountSeqCounter{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -307,6 +308,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if err := s.assignAccountSeqIDs(ctx, tx, account); err != nil {
|
||||
return fmt.Errorf("assign seq ids: %w", err)
|
||||
}
|
||||
|
||||
result = tx.
|
||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
@@ -658,6 +663,22 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
}
|
||||
|
||||
// CreateGroups creates the given list of groups to the database.
|
||||
// groupUpsertColumns is the explicit allowlist of columns that get updated when
|
||||
// CreateGroups / UpdateGroups hit a PK conflict. account_seq_id is intentionally
|
||||
// omitted so a caller passing an entity with the zero value (e.g. an HTTP
|
||||
// handler-built struct) cannot reset the persisted seq id during an upsert.
|
||||
// Keep this in sync with the Group schema in management/server/types/group.go.
|
||||
func groupUpsertColumns() clause.Set {
|
||||
return clause.AssignmentColumns([]string{
|
||||
"account_id",
|
||||
"name",
|
||||
"issued",
|
||||
"integration_ref_id",
|
||||
"integration_ref_integration_type",
|
||||
"resources",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
@@ -667,8 +688,9 @@ func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -692,8 +714,9 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
DoUpdates: groupUpsertColumns(),
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
@@ -1995,7 +2018,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) {
|
||||
const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2005,7 +2028,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
var resources []byte
|
||||
var refID sql.NullInt64
|
||||
var refType sql.NullString
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
err := row.Scan(&g.ID, &g.AccountID, &g.AccountSeqID, &g.Name, &g.Issued, &resources, &refID, &refType)
|
||||
if err == nil {
|
||||
if refID.Valid {
|
||||
g.IntegrationReference.ID = int(refID.Int64)
|
||||
@@ -2030,7 +2053,7 @@ func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Gr
|
||||
}
|
||||
|
||||
func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) {
|
||||
const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2039,7 +2062,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
var p types.Policy
|
||||
var checks []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks)
|
||||
err := row.Scan(&p.ID, &p.AccountID, &p.AccountSeqID, &p.Name, &p.Description, &enabled, &checks)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
p.Enabled = enabled.Bool
|
||||
@@ -2057,7 +2080,7 @@ func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.
|
||||
}
|
||||
|
||||
func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) {
|
||||
const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2067,7 +2090,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
var network, domains, peerGroups, groups, accessGroups []byte
|
||||
var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
err := row.Scan(&r.ID, &r.AccountID, &r.AccountSeqID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply)
|
||||
if err == nil {
|
||||
if keepRoute.Valid {
|
||||
r.KeepRoute = keepRoute.Bool
|
||||
@@ -2109,7 +2132,7 @@ func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Rou
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) {
|
||||
const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, account_seq_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2118,7 +2141,7 @@ func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([
|
||||
var n nbdns.NameServerGroup
|
||||
var ns, groups, domains []byte
|
||||
var primary, enabled, searchDomainsEnabled sql.NullBool
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
err := row.Scan(&n.ID, &n.AccountID, &n.AccountSeqID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled)
|
||||
if err == nil {
|
||||
if primary.Valid {
|
||||
n.Primary = primary.Bool
|
||||
@@ -2345,7 +2368,7 @@ func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networ
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2355,7 +2378,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
var peerGroups []byte
|
||||
var masquerade, enabled sql.NullBool
|
||||
var metric sql.NullInt64
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled)
|
||||
if err == nil {
|
||||
if masquerade.Valid {
|
||||
r.Masquerade = masquerade.Bool
|
||||
@@ -2383,7 +2406,7 @@ func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*
|
||||
}
|
||||
|
||||
func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
const query = `SELECT id, network_id, account_id, account_seq_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2392,7 +2415,7 @@ func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([
|
||||
var r resourceTypes.NetworkResource
|
||||
var prefix []byte
|
||||
var enabled sql.NullBool
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.AccountSeqID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled)
|
||||
if err == nil {
|
||||
if enabled.Valid {
|
||||
r.Enabled = enabled.Bool
|
||||
@@ -3565,6 +3588,145 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must be called inside ExecuteInTransaction so the increment
|
||||
// is serialized with the component insert.
|
||||
func (s *SqlStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
return allocateAccountSeqID(ctx, s.db, s.storeEngine, accountID, entity)
|
||||
}
|
||||
|
||||
func allocateAccountSeqID(_ context.Context, db *gorm.DB, engine types.Engine, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
switch engine {
|
||||
case types.PostgresStoreEngine, types.SqliteStoreEngine:
|
||||
return allocateAccountSeqIDReturning(db, accountID, entity)
|
||||
case types.MysqlStoreEngine:
|
||||
return allocateAccountSeqIDMysql(db, accountID, entity)
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported store engine for account_seq allocator: %v", engine)
|
||||
}
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDReturning runs a single atomic INSERT ... ON CONFLICT
|
||||
// DO UPDATE ... RETURNING that gives us the allocated id without a separate
|
||||
// SELECT FOR UPDATE. Two concurrent allocations for the same (account, entity)
|
||||
// produce two distinct ids: one wins the INSERT, the other wins the UPDATE
|
||||
// branch and returns next_id+1.
|
||||
func allocateAccountSeqIDReturning(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const sqlStr = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, 2)
|
||||
ON CONFLICT (account_id, entity) DO UPDATE
|
||||
SET next_id = account_seq_counters.next_id + 1
|
||||
RETURNING (next_id - 1)
|
||||
`
|
||||
var allocated uint32
|
||||
if err := db.Raw(sqlStr, accountID, string(entity)).Scan(&allocated).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
if allocated == 0 {
|
||||
return 0, fmt.Errorf("upsert account seq counter returned 0")
|
||||
}
|
||||
return allocated, nil
|
||||
}
|
||||
|
||||
// allocateAccountSeqIDMysql is the MySQL equivalent of allocateAccountSeqIDReturning.
|
||||
// MySQL has no RETURNING on ON DUPLICATE KEY UPDATE, so we use the LAST_INSERT_ID
|
||||
// trick: passing an expression to LAST_INSERT_ID(expr) both sets the session value
|
||||
// and returns it from the INSERT. The INSERT's value uses LAST_INSERT_ID(2) so the
|
||||
// no-conflict path also surfaces the new next_id, keeping the read-back uniform.
|
||||
// LAST_INSERT_ID is per-connection; GORM transactions pin a single connection,
|
||||
// so the follow-up SELECT sees the same value.
|
||||
func allocateAccountSeqIDMysql(db *gorm.DB, accountID string, entity types.AccountSeqEntity) (uint32, error) {
|
||||
const upsertSQL = `
|
||||
INSERT INTO account_seq_counters (account_id, entity, next_id)
|
||||
VALUES (?, ?, LAST_INSERT_ID(2))
|
||||
ON DUPLICATE KEY UPDATE next_id = LAST_INSERT_ID(next_id + 1)
|
||||
`
|
||||
if err := db.Exec(upsertSQL, accountID, string(entity)).Error; err != nil {
|
||||
return 0, fmt.Errorf("upsert account seq counter: %w", err)
|
||||
}
|
||||
var newNext uint64
|
||||
if err := db.Raw("SELECT LAST_INSERT_ID()").Scan(&newNext).Error; err != nil {
|
||||
return 0, fmt.Errorf("get last insert id: %w", err)
|
||||
}
|
||||
if newNext == 0 {
|
||||
return 0, fmt.Errorf("LAST_INSERT_ID returned 0; account_seq_counters misconfigured")
|
||||
}
|
||||
return uint32(newNext - 1), nil
|
||||
}
|
||||
|
||||
// assignAccountSeqIDs allocates a per-account integer id for any component on
|
||||
// the in-memory account whose AccountSeqID is zero. Called from SaveAccount so
|
||||
// the canonical "save the whole account" path produces the same persisted seq
|
||||
// ids that the manager-level Create paths produce. Update flows that go
|
||||
// through SaveAccount preserve existing non-zero values.
|
||||
func (s *SqlStore) assignAccountSeqIDs(ctx context.Context, tx *gorm.DB, account *types.Account) error {
|
||||
for i := range account.GroupsG {
|
||||
g := account.GroupsG[i]
|
||||
if g == nil || g.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.AccountSeqID = seq
|
||||
}
|
||||
for _, p := range account.Policies {
|
||||
if p == nil || p.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.AccountSeqID = seq
|
||||
}
|
||||
for i := range account.RoutesG {
|
||||
r := &account.RoutesG[i]
|
||||
if r.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.AccountSeqID = seq
|
||||
}
|
||||
for i := range account.NameServerGroupsG {
|
||||
ng := &account.NameServerGroupsG[i]
|
||||
if ng.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNameserverGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ng.AccountSeqID = seq
|
||||
}
|
||||
for _, nr := range account.NetworkResources {
|
||||
if nr == nil || nr.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkResource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
for _, nr := range account.NetworkRouters {
|
||||
if nr == nil || nr.AccountSeqID != 0 {
|
||||
continue
|
||||
}
|
||||
seq, err := allocateAccountSeqID(ctx, tx, s.storeEngine, account.Id, types.AccountSeqEntityNetworkRouter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nr.AccountSeqID = seq
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// transaction wraps a GORM transaction with MySQL-specific FK checks handling
|
||||
// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora
|
||||
func (s *SqlStore) transaction(fn func(*gorm.DB) error) error {
|
||||
@@ -3754,7 +3916,7 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
if err := s.db.Omit(clause.Associations, "account_seq_id").Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
@@ -3842,7 +4004,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
|
||||
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Omit("account_seq_id").Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
|
||||
@@ -220,6 +220,11 @@ type Store interface {
|
||||
GetStoreEngine() types.Engine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
|
||||
// AllocateAccountSeqID returns the next per-account integer id for the given
|
||||
// component kind. Must run inside a transaction so the increment is serialized
|
||||
// with the component insert.
|
||||
AllocateAccountSeqID(ctx context.Context, accountID string, entity types.AccountSeqEntity) (uint32, error)
|
||||
|
||||
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
|
||||
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
|
||||
SaveNetwork(ctx context.Context, network *networkTypes.Network) error
|
||||
@@ -522,6 +527,24 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Policy](ctx, db, types.AccountSeqEntityPolicy, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[types.Group](ctx, db, types.AccountSeqEntityGroup, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[route.Route](ctx, db, types.AccountSeqEntityRoute, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[resourceTypes.NetworkResource](ctx, db, types.AccountSeqEntityNetworkResource, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[routerTypes.NetworkRouter](ctx, db, types.AccountSeqEntityNetworkRouter, "id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.BackfillAccountSeqIDs[dns.NameServerGroup](ctx, db, types.AccountSeqEntityNameserverGroup, "id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -746,6 +746,21 @@ func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accou
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID mocks base method.
|
||||
func (m *MockStore) AllocateAccountSeqID(ctx context.Context, accountID string, entity types2.AccountSeqEntity) (uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AllocateAccountSeqID", ctx, accountID, entity)
|
||||
ret0, _ := ret[0].(uint32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AllocateAccountSeqID indicates an expected call of AllocateAccountSeqID.
|
||||
func (mr *MockStoreMockRecorder) AllocateAccountSeqID(ctx, accountID, entity interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllocateAccountSeqID", reflect.TypeOf((*MockStore)(nil).AllocateAccountSeqID), ctx, accountID, entity)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
27
management/server/types/account_seq_counter.go
Normal file
27
management/server/types/account_seq_counter.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package types
|
||||
|
||||
// AccountSeqEntity identifies the kind of component that uses a per-account sequence.
|
||||
type AccountSeqEntity string
|
||||
|
||||
const (
|
||||
AccountSeqEntityPolicy AccountSeqEntity = "policy"
|
||||
AccountSeqEntityGroup AccountSeqEntity = "group"
|
||||
AccountSeqEntityRoute AccountSeqEntity = "route"
|
||||
AccountSeqEntityNetworkResource AccountSeqEntity = "network_resource"
|
||||
AccountSeqEntityNetworkRouter AccountSeqEntity = "network_router"
|
||||
AccountSeqEntityNameserverGroup AccountSeqEntity = "nameserver_group"
|
||||
)
|
||||
|
||||
// AccountSeqCounter tracks the next per-account integer id for a given component
|
||||
// kind. Reads/writes go through the store inside the same transaction as the
|
||||
// component insert so two concurrent inserts cannot collide on the same id.
|
||||
type AccountSeqCounter struct {
|
||||
AccountID string `gorm:"primaryKey;size:255"`
|
||||
Entity string `gorm:"primaryKey;size:32"`
|
||||
NextID uint32 `gorm:"not null;default:1"`
|
||||
}
|
||||
|
||||
// TableName overrides the GORM-derived table name.
|
||||
func (AccountSeqCounter) TableName() string {
|
||||
return "account_seq_counters"
|
||||
}
|
||||
@@ -19,6 +19,10 @@ type Group struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_groups_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
|
||||
@@ -59,6 +59,10 @@ type Policy struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_policies_account_seq_id;not null;default:0"`
|
||||
|
||||
// Name of the Policy
|
||||
Name string
|
||||
|
||||
|
||||
@@ -95,6 +95,9 @@ type Route struct {
|
||||
ID ID `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// AccountSeqID is a per-account monotonically increasing identifier used as the
|
||||
// compact wire id when sending NetworkMap components to capable peers.
|
||||
AccountSeqID uint32 `json:"-" gorm:"index:idx_routes_account_seq_id;not null;default:0"`
|
||||
// Network and Domains are mutually exclusive
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
|
||||
Reference in New Issue
Block a user