diff --git a/management/server/affected_groups.go b/management/server/affected_groups.go new file mode 100644 index 000000000..05e40db0e --- /dev/null +++ b/management/server/affected_groups.go @@ -0,0 +1,220 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings, +// and network routers to collect all group IDs and direct peer IDs affected by the changed groups. +func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) { + if len(changedGroupIDs) == 0 { + return nil, nil + } + + changedSet := make(map[string]struct{}, len(changedGroupIDs)) + for _, id := range changedGroupIDs { + changedSet[id] = struct{}{} + } + + log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs) + + groupSet := make(map[string]struct{}) + peerSet := make(map[string]struct{}) + + collectPolicyAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + collectRouteAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + collectNameServerAffectedGroups(ctx, transaction, accountID, changedSet, groupSet) + collectDNSSettingsAffectedGroups(ctx, transaction, accountID, changedSet, groupSet) + collectNetworkRouterAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + + allGroupIDs = make([]string, 0, len(groupSet)) + for gID := range groupSet { + allGroupIDs = append(allGroupIDs, gID) + } + + directPeerIDs = make([]string, 0, len(peerSet)) + for pID := range peerSet { + directPeerIDs = append(directPeerIDs, pID) + } + + log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs) + + return allGroupIDs, directPeerIDs +} + +func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err) + return + } + + for _, policy := range policies { + if !policyReferencesGroups(policy, changedSet) { + continue + } + ruleGroups := policy.RuleGroups() + log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups) + for _, gID := range ruleGroups { + groupSet[gID] = struct{}{} + } + collectPolicyDirectPeers(ctx, policy, peerSet) + } +} + +func collectPolicyDirectPeers(ctx context.Context, policy *types.Policy, peerSet map[string]struct{}) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { + log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID) + peerSet[rule.SourceResource.ID] = struct{}{} + } + if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { + log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID) + peerSet[rule.DestinationResource.ID] = struct{}{} + } + } +} + +func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err) + return + } + + for _, r := range routes { + if !routeReferencesGroups(r, changedSet) { + continue + } + log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description) + for _, gID := range r.Groups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.PeerGroups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.AccessControlGroups { + groupSet[gID] = struct{}{} + } + if r.Peer != "" { + log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer) + peerSet[r.Peer] = struct{}{} + } + } +} + +func collectNameServerAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) { + nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err) + return + } + + for _, ns := range nsGroups { + if !nsReferencesGroups(ns, changedSet) { + continue + } + for _, g := range ns.Groups { + groupSet[g] = struct{}{} + } + } +} + +func nsReferencesGroups(ns *nbdns.NameServerGroup, changedSet map[string]struct{}) bool { + for _, gID := range ns.Groups { + if _, ok := changedSet[gID]; ok { + log.Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID) + return true + } + } + return false +} + +func collectDNSSettingsAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err) + return + } + + for _, gID := range dnsSettings.DisabledManagementGroups { + if _, ok := changedSet[gID]; ok { + log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID) + groupSet[gID] = struct{}{} + } + } +} + +func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err) + return + } + + for _, router := range routers { + if !routerReferencesGroups(router, changedSet) { + continue + } + log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID) + for _, gID := range router.PeerGroups { + groupSet[gID] = struct{}{} + } + if router.Peer != "" { + log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer) + peerSet[router.Peer] = struct{}{} + } + } +} + +func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { + for _, rule := range policy.Rules { + for _, gID := range rule.Sources { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range rule.Destinations { + if _, ok := groupSet[gID]; ok { + return true + } + } + } + return false +} + +func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { + for _, gID := range r.Groups { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range r.PeerGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range r.AccessControlGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + return false +} + +func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool { + for _, gID := range router.PeerGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + return false +} diff --git a/management/server/group_linkage.go b/management/server/group_linkage.go new file mode 100644 index 000000000..a69e7d7a7 --- /dev/null +++ b/management/server/group_linkage.go @@ -0,0 +1,226 @@ +package server + +import ( + "context" + "slices" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error { + // disable a deleting integration group if the initiator is not an admin service user + if group.Issued == types.GroupIssuedIntegration { + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) + if err != nil { + return status.Errorf(status.Internal, "failed to get user") + } + if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { + return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") + } + } + + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if len(group.Resources) > 0 { + return &GroupLinkError{"network resource", group.Resources[0].ID} + } + + if slices.Contains(flowGroups, group.ID) { + return &GroupLinkError{"settings", "traffic event logging"} + } + + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"route", string(linkedRoute.NetID)} + } + + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"name server groups", linkedDns.Name} + } + + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"policy", linkedPolicy.Name} + } + + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"setup key", linkedSetupKey.Name} + } + + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"user", linkedUser.Id} + } + + if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked { + return &GroupLinkError{"network router", linkedRouter.ID} + } + + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID) + if err != nil { + return status.Errorf(status.Internal, "failed to get DNS settings") + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { + return &GroupLinkError{"disabled DNS management groups", group.Name} + } + + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID) + if err != nil { + return status.Errorf(status.Internal, "failed to get account settings") + } + + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} + } + + return nil +} + +// isGroupLinkedToRoute checks if a group is linked to any route in the account. +func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + + for _, r := range routes { + isLinked := slices.Contains(r.Groups, groupID) || + slices.Contains(r.PeerGroups, groupID) || + slices.Contains(r.AccessControlGroups, groupID) + if isLinked { + return true, r + } + } + + return false, nil +} + +// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. +func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + + for _, policy := range policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { + return true, policy + } + } + } + return false, nil +} + +// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. +func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + + for _, dns := range nameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return true, dns + } + } + } + + return false, nil +} + +// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. +func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + + for _, setupKey := range setupKeys { + if slices.Contains(setupKey.AutoGroups, groupID) { + return true, setupKey + } + } + return false, nil +} + +// isGroupLinkedToUser checks if a group is linked to any user in the account. +func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + + for _, user := range users { + if slices.Contains(user.AutoGroups, groupID) { + return true, user + } + } + return false, nil +} + +// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. +func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) + return false, nil + } + + for _, router := range routers { + if slices.Contains(router.PeerGroups, groupID) { + return true, router + } + } + return false, nil +} + +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +}