diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index ca075d30f..bdd508e9b 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/management/server/group.go b/management/server/group.go index 8f8196e3b..69140bc00 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -10,13 +10,13 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/activity" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" - - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/status" ) type GroupLinkError struct { @@ -498,6 +498,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty return &GroupLinkError{"user", linkedUser.Id} } + if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"network router", linkedRouter.ID} + } + return checkGroupLinkedToSettings(ctx, transaction, group) } @@ -613,6 +617,22 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID return false, nil } +// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. +func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) + return false, nil + } + + for _, router := range routers { + if slices.Contains(router.PeerGroups, groupID) { + return true, router + } + } + return false, nil +} + // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { @@ -637,6 +657,9 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { return true, nil } + if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked { + return true, nil + } } return false, nil diff --git a/management/server/group_test.go b/management/server/group_test.go index b21b5e834..8cdef1dd8 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -11,9 +11,19 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/management-integrations/integrations" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" ) @@ -414,6 +424,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Name: "GroupD", Peers: []string{}, }, + { + ID: "groupE", + Name: "GroupE", + Peers: []string{peer2.ID}, + }, }) assert.NoError(t, err) @@ -673,4 +688,54 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Error("timeout waiting for peerShouldReceiveUpdate") } }) + + // Saving a group linked to network router should update account peers and send peer update + t.Run("saving group linked to network router", func(t *testing.T) { + userManager := users.NewManager(manager.Store) + extraSettingsManager := integrations.NewManager(nil) + settingsManager := settings.NewManager(manager.Store, userManager, extraSettingsManager) + permissionsManager := permissions.NewManager(userManager, settingsManager) + groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) + resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager) + routersManager := routers.NewManager(manager.Store, permissionsManager, manager) + networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) + + network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{ + ID: "network_test", + AccountID: account.Id, + Name: "network_test", + Description: "", + }) + require.NoError(t, err) + + _, err = routersManager.CreateRouter(context.Background(), userID, &routerTypes.NetworkRouter{ + ID: "router_test", + NetworkID: network.ID, + AccountID: account.Id, + PeerGroups: []string{"groupE"}, + Masquerade: true, + Metric: 9999, + Enabled: true, + }) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + ID: "groupE", + Name: "GroupE", + Peers: []string{peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) }