From a1cb95276426dfd7c87fc96d48b70ed2b8b7c536 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 10 Apr 2026 09:14:42 +0800 Subject: [PATCH] Reconcile IPv6 addresses on group membership changes (#5837) --- management/server/account.go | 100 ++++++++++++++++----- management/server/account_test.go | 10 +-- management/server/group.go | 71 ++++++++++----- management/server/group_ipv6_test.go | 125 +++++++++++++++++++++++++++ management/server/user.go | 6 ++ 5 files changed, 259 insertions(+), 53 deletions(-) create mode 100644 management/server/group_ipv6_test.go diff --git a/management/server/account.go b/management/server/account.go index 74cc93ca4..ee4760483 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -348,7 +348,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled { - groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID) + groupsUpdated, groupChangesAffectPeers, err = am.propagateUserGroupMemberships(ctx, transaction, accountID) if err != nil { return err } @@ -1599,6 +1599,11 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } + allGroupChanges := slices.Concat(addNewGroups, removeOldGroups) + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, userAuth.AccountId, allGroupChanges); err != nil { + return fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } + if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } @@ -2160,7 +2165,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. // Returns true if any groups were modified, true if those updates affect peers and an error. -func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { +func (am *DefaultAccountManager) propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, err @@ -2182,29 +2187,13 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, } } - updatedGroups := []string{} - for _, user := range users { - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) - if err != nil { - return false, false, err - } + updatedGroups, err := propagateAutoGroupsForUsers(ctx, transaction, accountID, users, accountGroupPeers) + if err != nil { + return false, false, err + } - for _, peer := range userPeers { - for _, groupID := range user.AutoGroups { - if _, exists := accountGroupPeers[groupID]; !exists { - // we do not wanna create the groups here - log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) - continue - } - if _, exists := accountGroupPeers[groupID][peer.ID]; exists { - continue - } - if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { - return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) - } - updatedGroups = append(updatedGroups, groupID) - } - } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, updatedGroups); err != nil { + return false, false, fmt.Errorf("reconcile IPv6 for group changes: %w", err) } peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) @@ -2215,6 +2204,35 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, return len(updatedGroups) > 0, peersAffected, nil } +// propagateAutoGroupsForUsers adds each user's peers to their AutoGroups where not already present. +// Returns the list of group IDs that were modified. +func propagateAutoGroupsForUsers(ctx context.Context, transaction store.Store, accountID string, users []*types.User, accountGroupPeers map[string]map[string]struct{}) ([]string, error) { + var updatedGroups []string + for _, user := range users { + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) + if err != nil { + return nil, err + } + + for _, peer := range userPeers { + for _, groupID := range user.AutoGroups { + if _, exists := accountGroupPeers[groupID]; !exists { + log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) + continue + } + if _, exists := accountGroupPeers[groupID][peer.ID]; exists { + continue + } + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return nil, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) + } + updatedGroups = append(updatedGroups, groupID) + } + } + } + return updatedGroups, nil +} + // reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error { if !newNetworkRange.IsValid() { @@ -2315,6 +2333,40 @@ func (am *DefaultAccountManager) updatePeerIPv6Addresses(ctx context.Context, tr return nil } +// reconcileIPv6ForGroupChanges checks whether the given group IDs overlap with +// the account's IPv6EnabledGroups. If they do, it runs a full IPv6 address +// reconciliation so that peers gaining or losing membership in an IPv6-enabled +// group get their addresses assigned or removed. +func (am *DefaultAccountManager) reconcileIPv6ForGroupChanges(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error { + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get account settings: %w", err) + } + + if len(settings.IPv6EnabledGroups) == 0 { + return nil + } + + enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups)) + for _, gid := range settings.IPv6EnabledGroups { + enabledSet[gid] = struct{}{} + } + + affected := false + for _, gid := range groupIDs { + if _, ok := enabledSet[gid]; ok { + affected = true + break + } + } + + if !affected { + return nil + } + + return am.updatePeerIPv6Addresses(ctx, transaction, accountID, settings) +} + func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error { if settings.NetworkRangeV6.IsValid() { network.NetV6 = net.IPNet{ diff --git a/management/server/account_test.go b/management/server/account_test.go index 915075adb..bdaa74e76 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3665,7 +3665,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { require.NoError(t, err) t.Run("should skip propagation when the user has no groups", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3681,7 +3681,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = append(user.AutoGroups, group1.ID) require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3719,7 +3719,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }, true) require.NoError(t, err) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.True(t, groupChangesAffectPeers) @@ -3734,7 +3734,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }) t.Run("should not update membership or account peers when no changes", func(t *testing.T) { - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) @@ -3747,7 +3747,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { user.AutoGroups = []string{"group1"} require.NoError(t, manager.Store.SaveUser(ctx, user)) - groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) diff --git a/management/server/group.go b/management/server/group.go index 7b5b9b86c..dadf7783b 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -174,6 +174,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -278,37 +282,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us var globalErr error groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { - return err - } - - newGroup.AccountID = accountID - - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { - return err - } - - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { - return err - } - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) - - groupIDs = append(groupIDs, newGroup.ID) - - return nil - }) + events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup) if err != nil { log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) if len(groups) == 1 { return err } globalErr = errors.Join(globalErr, err) - // continue updating other groups + continue } + eventsToStore = append(eventsToStore, events...) + groupIDs = append(groupIDs, newGroup.ID) } updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) @@ -327,6 +311,33 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us return globalErr } +func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) { + var events []func() + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + + if err := transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil { + return err + } + + if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + + events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + return nil + }) + return events, err +} + // prepareGroupEvents prepares a list of event functions to be stored. func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() @@ -458,6 +469,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, groupIDsToDelete); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -486,6 +501,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -552,6 +571,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } + if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil { + return err + } + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { diff --git a/management/server/group_ipv6_test.go b/management/server/group_ipv6_test.go new file mode 100644 index 000000000..e4603c879 --- /dev/null +++ b/management/server/group_ipv6_test.go @@ -0,0 +1,125 @@ +package server + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +// TestGroupIPv6Assignment verifies that peers gain or lose IPv6 addresses +// when they are added to or removed from an IPv6-enabled group. +func TestGroupIPv6Assignment(t *testing.T) { + am, _, err := createManager(t) + require.NoError(t, err) + + ctx := context.Background() + userID := groupAdminUserID + + account, err := createAccount(am, "ipv6-grp-test", userID, "ipv6test.example.com") + require.NoError(t, err) + + // Allocate IPv6 subnet for the account + account.Network.NetV6 = types.AllocateIPv6Subnet(rand.New(rand.NewSource(time.Now().UnixNano()))) + require.NoError(t, am.Store.SaveAccount(ctx, account)) + + // Create setup key + setupKey, err := am.CreateSetupKey(ctx, account.Id, "ipv6-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false) + require.NoError(t, err) + + // Create an IPv6-enabled group + ipv6GroupID := "ipv6-enabled-grp" + err = am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: ipv6GroupID, + Name: "IPv6 Enabled", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + // Enable IPv6 on that group + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + settings.IPv6EnabledGroups = []string{ipv6GroupID} + require.NoError(t, am.Store.SaveAccountSettings(ctx, account.Id, settings)) + + // Register a peer (will be in "All" group, not the IPv6 group) + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{ + Key: key.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"}, + }, false) + require.NoError(t, err) + assert.False(t, peer.IPv6.IsValid(), "peer should not have IPv6 before joining an IPv6-enabled group") + + t.Run("GroupAddPeer assigns IPv6", func(t *testing.T) { + err := am.GroupAddPeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have an IPv6 address after joining the group") + }) + + t.Run("GroupDeletePeer clears IPv6", func(t *testing.T) { + err := am.GroupDeletePeer(ctx, account.Id, ipv6GroupID, peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not have IPv6 after removal from the group") + }) + + t.Run("UpdateGroup with peer addition assigns IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = append(grp.Peers, peer.ID) + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.True(t, p.IPv6.IsValid(), "peer should have IPv6 after UpdateGroup adds it") + }) + + t.Run("UpdateGroup with peer removal clears IPv6", func(t *testing.T) { + grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID) + require.NoError(t, err) + + grp.Peers = []string{} + err = am.UpdateGroup(ctx, account.Id, userID, grp) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should lose IPv6 after UpdateGroup removes it") + }) + + t.Run("non-IPv6 group changes do not affect IPv6", func(t *testing.T) { + err := am.CreateGroup(ctx, account.Id, userID, &types.Group{ + ID: "regular-grp", + Name: "Regular Group", + Issued: types.GroupIssuedAPI, + Peers: []string{}, + }) + require.NoError(t, err) + + err = am.GroupAddPeer(ctx, account.Id, "regular-grp", peer.ID) + require.NoError(t, err) + + p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID) + require.NoError(t, err) + assert.False(t, p.IPv6.IsValid(), "peer should not get IPv6 from a non-IPv6 group") + }) +} diff --git a/management/server/user.go b/management/server/user.go index c1f984f2f..647d4cb81 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" "unicode" @@ -824,6 +825,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } } } + + allGroupChanges := slices.Concat(removedGroups, addedGroups) + if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, allGroupChanges); err != nil { + return false, nil, nil, nil, fmt.Errorf("reconcile IPv6 for group changes: %w", err) + } } updateAccountPeers := len(userPeers) > 0