Reconcile IPv6 addresses on group membership changes (#5837)

This commit is contained in:
Viktor Liu
2026-04-10 09:14:42 +08:00
committed by GitHub
parent 6e05a2ebe9
commit a1cb952764
5 changed files with 259 additions and 53 deletions

View File

@@ -348,7 +348,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
} }
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled {
groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID) groupsUpdated, groupChangesAffectPeers, err = am.propagateUserGroupMemberships(ctx, transaction, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1599,6 +1599,11 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
} }
allGroupChanges := slices.Concat(addNewGroups, removeOldGroups)
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, userAuth.AccountId, allGroupChanges); err != nil {
return fmt.Errorf("reconcile IPv6 for group changes: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
@@ -2160,7 +2165,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func (am *DefaultAccountManager) propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
@@ -2182,29 +2187,13 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
} }
} }
updatedGroups := []string{} updatedGroups, err := propagateAutoGroupsForUsers(ctx, transaction, accountID, users, accountGroupPeers)
for _, user := range users { if err != nil {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) return false, false, err
if err != nil { }
return false, false, err
}
for _, peer := range userPeers { if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, updatedGroups); err != nil {
for _, groupID := range user.AutoGroups { return false, false, fmt.Errorf("reconcile IPv6 for group changes: %w", err)
if _, exists := accountGroupPeers[groupID]; !exists {
// we do not wanna create the groups here
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
continue
}
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
continue
}
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
}
updatedGroups = append(updatedGroups, groupID)
}
}
} }
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups)
@@ -2215,6 +2204,35 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
return len(updatedGroups) > 0, peersAffected, nil return len(updatedGroups) > 0, peersAffected, nil
} }
// propagateAutoGroupsForUsers adds each user's peers to their AutoGroups where not already present.
// Returns the list of group IDs that were modified.
func propagateAutoGroupsForUsers(ctx context.Context, transaction store.Store, accountID string, users []*types.User, accountGroupPeers map[string]map[string]struct{}) ([]string, error) {
var updatedGroups []string
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil {
return nil, err
}
for _, peer := range userPeers {
for _, groupID := range user.AutoGroups {
if _, exists := accountGroupPeers[groupID]; !exists {
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
continue
}
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
continue
}
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return nil, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
}
updatedGroups = append(updatedGroups, groupID)
}
}
}
return updatedGroups, nil
}
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes // reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error { func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
if !newNetworkRange.IsValid() { if !newNetworkRange.IsValid() {
@@ -2315,6 +2333,40 @@ func (am *DefaultAccountManager) updatePeerIPv6Addresses(ctx context.Context, tr
return nil return nil
} }
// reconcileIPv6ForGroupChanges checks whether the given group IDs overlap with
// the account's IPv6EnabledGroups. If they do, it runs a full IPv6 address
// reconciliation so that peers gaining or losing membership in an IPv6-enabled
// group get their addresses assigned or removed.
func (am *DefaultAccountManager) reconcileIPv6ForGroupChanges(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error {
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
if len(settings.IPv6EnabledGroups) == 0 {
return nil
}
enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups))
for _, gid := range settings.IPv6EnabledGroups {
enabledSet[gid] = struct{}{}
}
affected := false
for _, gid := range groupIDs {
if _, ok := enabledSet[gid]; ok {
affected = true
break
}
}
if !affected {
return nil
}
return am.updatePeerIPv6Addresses(ctx, transaction, accountID, settings)
}
func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error { func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error {
if settings.NetworkRangeV6.IsValid() { if settings.NetworkRangeV6.IsValid() {
network.NetV6 = net.IPNet{ network.NetV6 = net.IPNet{

View File

@@ -3665,7 +3665,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) { t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
@@ -3681,7 +3681,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
@@ -3719,7 +3719,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}, true) }, true)
require.NoError(t, err) require.NoError(t, err)
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
@@ -3734,7 +3734,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not update membership or account peers when no changes", func(t *testing.T) { t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
@@ -3747,7 +3747,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)

View File

@@ -174,6 +174,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
}) })
if err != nil { if err != nil {
@@ -278,37 +282,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
var globalErr error var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
groupIDs = append(groupIDs, newGroup.ID)
return nil
})
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groups) == 1 { if len(groups) == 1 {
return err return err
} }
globalErr = errors.Join(globalErr, err) globalErr = errors.Join(globalErr, err)
// continue updating other groups continue
} }
eventsToStore = append(eventsToStore, events...)
groupIDs = append(groupIDs, newGroup.ID)
} }
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
@@ -327,6 +311,33 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
return globalErr return globalErr
} }
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
var events []func()
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil {
return err
}
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
return nil
})
return events, err
}
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
@@ -458,6 +469,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
return err return err
} }
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, groupIDsToDelete); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
}) })
if err != nil { if err != nil {
@@ -486,6 +501,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err return err
} }
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
}) })
if err != nil { if err != nil {
@@ -552,6 +571,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err return err
} }
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
}) })
if err != nil { if err != nil {

View File

@@ -0,0 +1,125 @@
package server
import (
"context"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// TestGroupIPv6Assignment verifies that peers gain or lose IPv6 addresses
// when they are added to or removed from an IPv6-enabled group.
func TestGroupIPv6Assignment(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
userID := groupAdminUserID
account, err := createAccount(am, "ipv6-grp-test", userID, "ipv6test.example.com")
require.NoError(t, err)
// Allocate IPv6 subnet for the account
account.Network.NetV6 = types.AllocateIPv6Subnet(rand.New(rand.NewSource(time.Now().UnixNano())))
require.NoError(t, am.Store.SaveAccount(ctx, account))
// Create setup key
setupKey, err := am.CreateSetupKey(ctx, account.Id, "ipv6-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
// Create an IPv6-enabled group
ipv6GroupID := "ipv6-enabled-grp"
err = am.CreateGroup(ctx, account.Id, userID, &types.Group{
ID: ipv6GroupID,
Name: "IPv6 Enabled",
Issued: types.GroupIssuedAPI,
Peers: []string{},
})
require.NoError(t, err)
// Enable IPv6 on that group
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
require.NoError(t, err)
settings.IPv6EnabledGroups = []string{ipv6GroupID}
require.NoError(t, am.Store.SaveAccountSettings(ctx, account.Id, settings))
// Register a peer (will be in "All" group, not the IPv6 group)
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
}, false)
require.NoError(t, err)
assert.False(t, peer.IPv6.IsValid(), "peer should not have IPv6 before joining an IPv6-enabled group")
t.Run("GroupAddPeer assigns IPv6", func(t *testing.T) {
err := am.GroupAddPeer(ctx, account.Id, ipv6GroupID, peer.ID)
require.NoError(t, err)
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
require.NoError(t, err)
assert.True(t, p.IPv6.IsValid(), "peer should have an IPv6 address after joining the group")
})
t.Run("GroupDeletePeer clears IPv6", func(t *testing.T) {
err := am.GroupDeletePeer(ctx, account.Id, ipv6GroupID, peer.ID)
require.NoError(t, err)
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
require.NoError(t, err)
assert.False(t, p.IPv6.IsValid(), "peer should not have IPv6 after removal from the group")
})
t.Run("UpdateGroup with peer addition assigns IPv6", func(t *testing.T) {
grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID)
require.NoError(t, err)
grp.Peers = append(grp.Peers, peer.ID)
err = am.UpdateGroup(ctx, account.Id, userID, grp)
require.NoError(t, err)
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
require.NoError(t, err)
assert.True(t, p.IPv6.IsValid(), "peer should have IPv6 after UpdateGroup adds it")
})
t.Run("UpdateGroup with peer removal clears IPv6", func(t *testing.T) {
grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID)
require.NoError(t, err)
grp.Peers = []string{}
err = am.UpdateGroup(ctx, account.Id, userID, grp)
require.NoError(t, err)
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
require.NoError(t, err)
assert.False(t, p.IPv6.IsValid(), "peer should lose IPv6 after UpdateGroup removes it")
})
t.Run("non-IPv6 group changes do not affect IPv6", func(t *testing.T) {
err := am.CreateGroup(ctx, account.Id, userID, &types.Group{
ID: "regular-grp",
Name: "Regular Group",
Issued: types.GroupIssuedAPI,
Peers: []string{},
})
require.NoError(t, err)
err = am.GroupAddPeer(ctx, account.Id, "regular-grp", peer.ID)
require.NoError(t, err)
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
require.NoError(t, err)
assert.False(t, p.IPv6.IsValid(), "peer should not get IPv6 from a non-IPv6 group")
})
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
"unicode" "unicode"
@@ -824,6 +825,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
} }
} }
} }
allGroupChanges := slices.Concat(removedGroups, addedGroups)
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, allGroupChanges); err != nil {
return false, nil, nil, nil, fmt.Errorf("reconcile IPv6 for group changes: %w", err)
}
} }
updateAccountPeers := len(userPeers) > 0 updateAccountPeers := len(userPeers) > 0