diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index 4eab61e54..17a23914a 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -17,6 +17,7 @@ func (a *Account) GetPeerNetworkMapComponents( validatedPeersMap map[string]struct{}, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, + groupIDToUserIDs map[string][]string, ) *NetworkMapComponents { peer := a.Peers[peerID] @@ -42,6 +43,8 @@ func (a *Account) GetPeerNetworkMapComponents( ResourcePoliciesMap: make(map[string][]*Policy), RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), NetworkResources: make([]*resourceTypes.NetworkResource, 0), + GroupIDToUserIDs: groupIDToUserIDs, + AllowedUserIDs: a.getAllowedUserIDs(), } components.AccountSettings = &AccountSettingsInfo{ diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go index fdfeafe02..46a775dfb 100644 --- a/management/server/types/networkmap_comparison_test.go +++ b/management/server/types/networkmap_comparison_test.go @@ -61,6 +61,7 @@ func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { validatedPeersMap, resourcePolicies, routers, + groupIDToUserIDs, ) if components == nil { @@ -113,6 +114,7 @@ func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { validatedPeersMap, resourcePolicies, routers, + groupIDToUserIDs, ) require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") @@ -477,6 +479,7 @@ func BenchmarkComponentsNetworkMap(b *testing.B) { peersCustomZone := nbdns.CustomZone{} resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -487,6 +490,7 @@ func BenchmarkComponentsNetworkMap(b *testing.B) { validatedPeersMap, resourcePolicies, routers, + groupIDToUserIDs, ) _ = CalculateNetworkMapFromComponents(ctx, components) } @@ -507,6 +511,7 @@ func BenchmarkComponentsCreation(b *testing.B) { peersCustomZone := nbdns.CustomZone{} resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -517,6 +522,7 @@ func BenchmarkComponentsCreation(b *testing.B) { validatedPeersMap, resourcePolicies, routers, + groupIDToUserIDs, ) } } @@ -536,6 +542,7 @@ func BenchmarkCalculationFromComponents(b *testing.B) { peersCustomZone := nbdns.CustomZone{} resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() components := account.GetPeerNetworkMapComponents( ctx, @@ -544,6 +551,7 @@ func BenchmarkCalculationFromComponents(b *testing.B) { validatedPeersMap, resourcePolicies, routers, + groupIDToUserIDs, ) b.ResetTimer() @@ -576,7 +584,7 @@ func TestGetPeerNetworkMap_ProdAccount_CompareImplementations(t *testing.T) { legacyNetworkMap := testAccount.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, groupIDToUserIDs) require.NotNil(t, legacyNetworkMap, "GetPeerNetworkMap returned nil") - components := testAccount.GetPeerNetworkMapComponents(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers) + components := testAccount.GetPeerNetworkMapComponents(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs) require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) @@ -678,7 +686,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() agUsers := account.GetActiveGroupUsers() - components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers) + components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers, agUsers) componentsJSON, err := json.Marshal(components) require.NoError(b, err) @@ -729,6 +737,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { validatedPeersMap, resourcePolicies, routers, + agUsers, ) _ = CalculateNetworkMapFromComponents(ctx, components) } @@ -744,6 +753,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { validatedPeersMap, resourcePolicies, routers, + agUsers, ) } }) diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 1bb6c0c9d..3c6ff4df1 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" nbpeer "github.com/netbirdio/netbird/management/server/peer" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -33,6 +34,9 @@ type NetworkMapComponents struct { ResourcePoliciesMap map[string][]*Policy RoutersMap map[string]map[string]*routerTypes.NetworkRouter NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} } type AccountSettingsInfo struct { @@ -103,7 +107,7 @@ func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkM func (calc *NetworkMapCalculator) Calculate(ctx context.Context) *NetworkMap { targetPeerID := calc.components.PeerID - aclPeers, firewallRules := calc.getPeerConnectionResources(ctx, targetPeerID) + aclPeers, firewallRules, authorizedUsers, sshEnabled := calc.getPeerConnectionResources(ctx, targetPeerID) peersToConnect, expiredPeers := calc.filterPeersByLoginExpiration(aclPeers) @@ -151,16 +155,20 @@ func (calc *NetworkMapCalculator) Calculate(ctx context.Context) *NetworkMap { OfflinePeers: expiredPeers, FirewallRules: firewallRules, RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...), + AuthorizedUsers: authorizedUsers, + EnableSSH: sshEnabled, } } -func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context, targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule) { +func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context, targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { targetPeer := calc.components.GetPeerInfo(targetPeerID) if targetPeer == nil { - return nil, nil + return nil, nil, nil, false } generateResources, getAccumulatedResources := calc.connResourcesGenerator(ctx, targetPeer) + authorizedUsers := make(map[string]map[string]struct{}) + sshEnabled := false for _, policy := range calc.components.Policies { if !policy.Enabled { @@ -203,10 +211,58 @@ func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context if peerInDestinations { generateResources(rule, sourcePeers, FirewallRuleDirectionIN) } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := calc.components.GroupIDToUserIDs[groupID] + if !ok { + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = calc.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = calc.getAllowedUserIDs() + } } } - return getAccumulatedResources() + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (calc *NetworkMapCalculator) getAllowedUserIDs() map[string]struct{} { + if calc.components.AllowedUserIDs != nil { + result := make(map[string]struct{}, len(calc.components.AllowedUserIDs)) + for k, v := range calc.components.AllowedUserIDs { + result[k] = v + } + return result + } + return make(map[string]struct{}) } func (calc *NetworkMapCalculator) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {