mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Reconcile IPv6 addresses on group membership changes (#5837)
This commit is contained in:
@@ -348,7 +348,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled {
|
||||
groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID)
|
||||
groupsUpdated, groupChangesAffectPeers, err = am.propagateUserGroupMemberships(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1599,6 +1599,11 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
}
|
||||
|
||||
allGroupChanges := slices.Concat(addNewGroups, removeOldGroups)
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, userAuth.AccountId, allGroupChanges); err != nil {
|
||||
return fmt.Errorf("reconcile IPv6 for group changes: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
@@ -2160,7 +2165,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
|
||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
||||
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||
func (am *DefaultAccountManager) propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
@@ -2182,29 +2187,13 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
||||
}
|
||||
}
|
||||
|
||||
updatedGroups := []string{}
|
||||
for _, user := range users {
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
updatedGroups, err := propagateAutoGroupsForUsers(ctx, transaction, accountID, users, accountGroupPeers)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
if _, exists := accountGroupPeers[groupID]; !exists {
|
||||
// we do not wanna create the groups here
|
||||
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
|
||||
continue
|
||||
}
|
||||
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
||||
return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
|
||||
}
|
||||
updatedGroups = append(updatedGroups, groupID)
|
||||
}
|
||||
}
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, updatedGroups); err != nil {
|
||||
return false, false, fmt.Errorf("reconcile IPv6 for group changes: %w", err)
|
||||
}
|
||||
|
||||
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups)
|
||||
@@ -2215,6 +2204,35 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
|
||||
return len(updatedGroups) > 0, peersAffected, nil
|
||||
}
|
||||
|
||||
// propagateAutoGroupsForUsers adds each user's peers to their AutoGroups where not already present.
|
||||
// Returns the list of group IDs that were modified.
|
||||
func propagateAutoGroupsForUsers(ctx context.Context, transaction store.Store, accountID string, users []*types.User, accountGroupPeers map[string]map[string]struct{}) ([]string, error) {
|
||||
var updatedGroups []string
|
||||
for _, user := range users {
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
if _, exists := accountGroupPeers[groupID]; !exists {
|
||||
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
|
||||
continue
|
||||
}
|
||||
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
||||
return nil, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
|
||||
}
|
||||
updatedGroups = append(updatedGroups, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return updatedGroups, nil
|
||||
}
|
||||
|
||||
// reallocateAccountPeerIPs re-allocates all peer IPs when the network range changes
|
||||
func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, transaction store.Store, accountID string, newNetworkRange netip.Prefix) error {
|
||||
if !newNetworkRange.IsValid() {
|
||||
@@ -2315,6 +2333,40 @@ func (am *DefaultAccountManager) updatePeerIPv6Addresses(ctx context.Context, tr
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileIPv6ForGroupChanges checks whether the given group IDs overlap with
|
||||
// the account's IPv6EnabledGroups. If they do, it runs a full IPv6 address
|
||||
// reconciliation so that peers gaining or losing membership in an IPv6-enabled
|
||||
// group get their addresses assigned or removed.
|
||||
func (am *DefaultAccountManager) reconcileIPv6ForGroupChanges(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) error {
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
|
||||
if len(settings.IPv6EnabledGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabledSet := make(map[string]struct{}, len(settings.IPv6EnabledGroups))
|
||||
for _, gid := range settings.IPv6EnabledGroups {
|
||||
enabledSet[gid] = struct{}{}
|
||||
}
|
||||
|
||||
affected := false
|
||||
for _, gid := range groupIDs {
|
||||
if _, ok := enabledSet[gid]; ok {
|
||||
affected = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !affected {
|
||||
return nil
|
||||
}
|
||||
|
||||
return am.updatePeerIPv6Addresses(ctx, transaction, accountID, settings)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ensureIPv6Subnet(ctx context.Context, transaction store.Store, accountID string, settings *types.Settings, network *types.Network) error {
|
||||
if settings.NetworkRangeV6.IsValid() {
|
||||
network.NetV6 = net.IPNet{
|
||||
|
||||
@@ -3665,7 +3665,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
@@ -3681,7 +3681,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
@@ -3719,7 +3719,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.True(t, groupChangesAffectPeers)
|
||||
@@ -3734,7 +3734,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
@@ -3747,7 +3747,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
user.AutoGroups = []string{"group1"}
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
groupsUpdated, groupChangesAffectPeers, err := manager.propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
@@ -174,6 +174,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -278,37 +282,17 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groups) == 1 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
continue
|
||||
}
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
@@ -327,6 +311,33 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
return globalErr
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
|
||||
var events []func()
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err := transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{newGroup.ID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
return nil
|
||||
})
|
||||
return events, err
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
@@ -458,6 +469,10 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -486,6 +501,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -552,6 +571,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
125
management/server/group_ipv6_test.go
Normal file
125
management/server/group_ipv6_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestGroupIPv6Assignment verifies that peers gain or lose IPv6 addresses
|
||||
// when they are added to or removed from an IPv6-enabled group.
|
||||
func TestGroupIPv6Assignment(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
userID := groupAdminUserID
|
||||
|
||||
account, err := createAccount(am, "ipv6-grp-test", userID, "ipv6test.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Allocate IPv6 subnet for the account
|
||||
account.Network.NetV6 = types.AllocateIPv6Subnet(rand.New(rand.NewSource(time.Now().UnixNano())))
|
||||
require.NoError(t, am.Store.SaveAccount(ctx, account))
|
||||
|
||||
// Create setup key
|
||||
setupKey, err := am.CreateSetupKey(ctx, account.Id, "ipv6-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an IPv6-enabled group
|
||||
ipv6GroupID := "ipv6-enabled-grp"
|
||||
err = am.CreateGroup(ctx, account.Id, userID, &types.Group{
|
||||
ID: ipv6GroupID,
|
||||
Name: "IPv6 Enabled",
|
||||
Issued: types.GroupIssuedAPI,
|
||||
Peers: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Enable IPv6 on that group
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
|
||||
require.NoError(t, err)
|
||||
settings.IPv6EnabledGroups = []string{ipv6GroupID}
|
||||
require.NoError(t, am.Store.SaveAccountSettings(ctx, account.Id, settings))
|
||||
|
||||
// Register a peer (will be in "All" group, not the IPv6 group)
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, peer.IPv6.IsValid(), "peer should not have IPv6 before joining an IPv6-enabled group")
|
||||
|
||||
t.Run("GroupAddPeer assigns IPv6", func(t *testing.T) {
|
||||
err := am.GroupAddPeer(ctx, account.Id, ipv6GroupID, peer.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, p.IPv6.IsValid(), "peer should have an IPv6 address after joining the group")
|
||||
})
|
||||
|
||||
t.Run("GroupDeletePeer clears IPv6", func(t *testing.T) {
|
||||
err := am.GroupDeletePeer(ctx, account.Id, ipv6GroupID, peer.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, p.IPv6.IsValid(), "peer should not have IPv6 after removal from the group")
|
||||
})
|
||||
|
||||
t.Run("UpdateGroup with peer addition assigns IPv6", func(t *testing.T) {
|
||||
grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
grp.Peers = append(grp.Peers, peer.ID)
|
||||
err = am.UpdateGroup(ctx, account.Id, userID, grp)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, p.IPv6.IsValid(), "peer should have IPv6 after UpdateGroup adds it")
|
||||
})
|
||||
|
||||
t.Run("UpdateGroup with peer removal clears IPv6", func(t *testing.T) {
|
||||
grp, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, ipv6GroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
grp.Peers = []string{}
|
||||
err = am.UpdateGroup(ctx, account.Id, userID, grp)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, p.IPv6.IsValid(), "peer should lose IPv6 after UpdateGroup removes it")
|
||||
})
|
||||
|
||||
t.Run("non-IPv6 group changes do not affect IPv6", func(t *testing.T) {
|
||||
err := am.CreateGroup(ctx, account.Id, userID, &types.Group{
|
||||
ID: "regular-grp",
|
||||
Name: "Regular Group",
|
||||
Issued: types.GroupIssuedAPI,
|
||||
Peers: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.GroupAddPeer(ctx, account.Id, "regular-grp", peer.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, account.Id, peer.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, p.IPv6.IsValid(), "peer should not get IPv6 from a non-IPv6 group")
|
||||
})
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
@@ -824,6 +825,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allGroupChanges := slices.Concat(removedGroups, addedGroups)
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, allGroupChanges); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("reconcile IPv6 for group changes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers := len(userPeers) > 0
|
||||
|
||||
Reference in New Issue
Block a user