diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index a2e5877f1..27ab456e5 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -41,8 +41,6 @@ func (a *Account) GetPeerNetworkMapComponents( ResourcePoliciesMap: make(map[string][]*Policy), RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), NetworkResources: make([]*resourceTypes.NetworkResource, 0), - GroupIDToUserIDs: groupIDToUserIDs, - AllowedUserIDs: a.getAllowedUserIDs(), PostureFailedPeers: make(map[string]map[string]struct{}, len(a.Policies)), } @@ -55,7 +53,14 @@ func (a *Account) GetPeerNetworkMapComponents( components.DNSSettings = &a.DNSSettings - relevantPeers, relevantGroups, relevantPolicies, relevantRoutes := a.getPeersGroupsPoliciesRoutes(ctx, peerID, validatedPeersMap, &components.PostureFailedPeers) + relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers) + + if len(sshReqs.neededGroupIDs) > 0 { + components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs) + } + if sshReqs.needAllowedUserIDs { + components.AllowedUserIDs = a.getAllowedUserIDs() + } components.Peers = relevantPeers components.Groups = relevantGroups @@ -173,16 +178,23 @@ func (a *Account) GetPeerNetworkMapComponents( return components } +type sshRequirements struct { + neededGroupIDs map[string]struct{} + needAllowedUserIDs bool +} + func (a *Account) getPeersGroupsPoliciesRoutes( ctx context.Context, peerID string, + peerSSHEnabled bool, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}, -) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route) { +) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) { relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4) relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4) relevantPolicies := make([]*Policy, 0, len(a.Policies)) relevantRoutes := make([]*route.Route, 0, len(a.Routes)) + sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})} relevantPeerIDs[peerID] = a.GetPeer(peerID) @@ -274,6 +286,20 @@ func (a *Account) getPeersGroupsPoliciesRoutes( for _, srcGroupID := range rule.Sources { relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) } + + if rule.Protocol == PolicyRuleProtocolNetbirdSSH { + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID := range rule.AuthorizedGroups { + sshReqs.neededGroupIDs[groupID] = struct{}{} + } + case rule.AuthorizedUser != "": + default: + sshReqs.needAllowedUserIDs = true + } + } else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled { + sshReqs.needAllowedUserIDs = true + } } } if policyRelevant { @@ -281,7 +307,7 @@ func (a *Account) getPeersGroupsPoliciesRoutes( } } - return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes + return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs } func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, @@ -442,3 +468,17 @@ func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbp return filteredRecords } + +func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string { + if len(neededGroupIDs) == 0 { + return nil + } + + filtered := make(map[string][]string, len(neededGroupIDs)) + for groupID := range neededGroupIDs { + if users, ok := fullMap[groupID]; ok { + filtered[groupID] = users + } + } + return filtered +}