diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 0fbd32771..375f6df1c 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() { return fmt.Errorf("upstream check call error") } - err := backoff.Retry(operation, exponentialBackOff) + err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) if err != nil { - log.Warn(err) + if errors.Is(err, context.Canceled) { + log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) + } else { + log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) + } return } diff --git a/combined/cmd/config.go b/combined/cmd/config.go index 04155f72e..d0ffa4ba4 100644 --- a/combined/cmd/config.go +++ b/combined/cmd/config.go @@ -70,6 +70,7 @@ type ServerConfig struct { DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"` Auth AuthConfig `yaml:"auth"` Store StoreConfig `yaml:"store"` + ActivityStore StoreConfig `yaml:"activityStore"` ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"` } diff --git a/combined/cmd/root.go b/combined/cmd/root.go index b8ea7064c..00edcb5d4 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -141,6 +141,17 @@ func initializeConfig() error { } } + if engine := config.Server.ActivityStore.Engine; engine != "" { + engineLower := strings.ToLower(engine) + if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" { + return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres") + } + os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower) + if dsn := config.Server.ActivityStore.DSN; dsn != "" { + os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn) + } + } + log.Infof("Starting combined NetBird server") logConfig(config) logEnvVars() @@ -668,8 +679,11 @@ func logEnvVars() { if strings.HasPrefix(env, "NB_") { key, _, _ := strings.Cut(env, "=") value := os.Getenv(key) - if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") { + keyLower := strings.ToLower(key) + if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") { value = maskSecret(value) + } else if strings.Contains(keyLower, "dsn") { + value = maskDSNPassword(value) } log.Infof(" %s=%s", key, value) found = true diff --git a/combined/config.yaml.example b/combined/config.yaml.example index b3b38c5a9..ad033396d 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -104,6 +104,11 @@ server: dsn: "" # Connection string for postgres or mysql encryptionKey: "" + # Activity events store configuration (optional, defaults to sqlite in dataDir) + # activityStore: + # engine: "sqlite" # sqlite or postgres + # dsn: "" # Connection string for postgres + # Reverse proxy settings (optional) # reverseProxy: # trustedHTTPProxies: [] diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index b2b65f47a..121c55ac5 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -63,6 +63,8 @@ type Controller struct { expNewNetworkMap bool expNewNetworkMapAIDs map[string]struct{} + + compactedNetworkMap bool } type bufferUpdate struct { @@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App newNetworkMapBuilder = false } + compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err) + compactedNetworkMap = false + } + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") expIDs := make(map[string]struct{}, len(ids)) for _, id := range ids { @@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, expNewNetworkMapAIDs: expIDs, + + compactedNetworkMap: compactedNetworkMap, } } @@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin var remotePeerNetworkMap *types.NetworkMap - if c.experimentalNetworkMap(accountID) { + switch { + case c.experimentalNetworkMap(accountID): remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { + case c.compactedNetworkMap: + remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + default: remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } @@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe var remotePeerNetworkMap *types.NetworkMap - if c.experimentalNetworkMap(accountId) { + switch { + case c.experimentalNetworkMap(accountId): remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { + case c.compactedNetworkMap: + remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + default: remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) } @@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr } else { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers()) + groupIDToUserIDs := account.GetActiveGroupUsers() + if c.compactedNetworkMap { + networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N account.InjectProxyPolicies(ctx) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) + groupIDToUserIDs := account.GetActiveGroupUsers() + if c.compactedNetworkMap { + networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index d5219511b..a20f3049b 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -14,6 +14,7 @@ type Manager interface { CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) DeleteService(ctx context.Context, accountID, userID, serviceID string) error + DeleteAllServices(ctx context.Context, accountID, userID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error SetStatus(ctx context.Context, accountID, serviceID string, status Status) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 673e11fa9..e8aa9eb34 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -50,6 +50,20 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service) } +// DeleteAllServices mocks base method. +func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllServices indicates an expected call of DeleteAllServices. +func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 9927f4244..ea1f5af91 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -344,6 +345,22 @@ func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID } } +func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) { + oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() + mapping := service.ToProtoMapping(operation, oldService, oidcCfg) + m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster) +} + +func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) { + if len(mappings) == 0 { + return + } + update := &proto.GetMappingUpdateResponse{ + Mapping: mappings, + } + m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster) +} + // validateTargetReferences checks that all target IDs reference existing peers or resources in the account. func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error { for _, target := range targets { diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index bdaeda866..6f75a4d39 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -105,7 +105,7 @@ type proxyConnection struct { proxyID string address string stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.ProxyMapping + sendChan chan *proto.GetMappingUpdateResponse ctx context.Context cancel context.CancelFunc } @@ -114,7 +114,6 @@ type proxyConnection struct { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ - updatesChan: make(chan *proto.ProxyMapping, 100), accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, @@ -203,7 +202,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyID: proxyID, address: proxyAddress, stream: stream, - sendChan: make(chan *proto.ProxyMapping, 100), + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), ctx: connCtx, cancel: cancel, } @@ -349,7 +348,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) for { select { case msg := <-conn.sendChan: - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil { + if err := conn.stream.Send(msg); err != nil { errChan <- err return } @@ -400,7 +399,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { +func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) @@ -410,7 +409,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { } select { case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID) + log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) } @@ -494,23 +493,31 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } // perProxyMessage returns a copy of update with a fresh one-time token for -// create/update operations. For delete operations the original message is -// returned unchanged because proxies do not need to authenticate for removal. +// create/update operations. For delete operations the original mapping is +// used unchanged because proxies do not need to authenticate for removal. // Returns nil if token generation fails (the proxy should be skipped). -func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping { - if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" { - return update +func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { + resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) + for _, mapping := range update.Mapping { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + resp = append(resp, mapping) + continue + } + + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + if err != nil { + log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) + return nil + } + + msg := shallowCloneMapping(mapping) + msg.AuthToken = token + resp = append(resp, msg) } - token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute) - if err != nil { - log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) - return nil + return &proto.GetMappingUpdateResponse{ + Mapping: resp, } - - msg := shallowCloneMapping(update) - msg.AuthToken = token - return msg } // shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index aaf994bd5..c368ba659 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -17,6 +17,10 @@ type mockReverseProxyManager struct { err error } +func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { if m.err != nil { return nil, m.err diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index f8863d78f..841686f30 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -17,8 +17,8 @@ import ( // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. -func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping { - ch := make(chan *proto.ProxyMapping, 10) +func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, @@ -32,7 +32,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan return ch } -func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { +func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { select { case msg := <-ch: return msg @@ -46,20 +46,19 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } const cluster = "proxy.example.com" const numProxies = 3 - channels := make([]chan *proto.ProxyMapping, numProxies) + channels := make([]chan *proto.GetMappingUpdateResponse, numProxies) for i := range numProxies { id := "proxy-" + string(rune('a'+i)) channels[i] = registerFakeProxy(s, id, cluster) } - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", @@ -73,10 +72,12 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { tokens := make([]string, numProxies) for i, ch := range channels { - msg := drainChannel(ch) - require.NotNil(t, msg, "proxy %d should receive a message", i) - assert.Equal(t, update.Domain, msg.Domain) - assert.Equal(t, update.Id, msg.Id) + resp := drainChannel(ch) + require.NotNil(t, resp, "proxy %d should receive a message", i) + require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i) + msg := resp.Mapping[0] + assert.Equal(t, mapping.Domain, msg.Domain) + assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) tokens[i] = msg.AuthToken } @@ -101,15 +102,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } const cluster = "proxy.example.com" ch1 := registerFakeProxy(s, "proxy-a", cluster) ch2 := registerFakeProxy(s, "proxy-b", cluster) - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "service-1", AccountId: "account-1", @@ -118,10 +118,12 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), update, cluster) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) - require.NotNil(t, msg1) - require.NotNil(t, msg2) + resp1 := drainChannel(ch1) + resp2 := drainChannel(ch2) + require.NotNil(t, resp1) + require.NotNil(t, resp2) + require.Len(t, resp1.Mapping, 1) + require.Len(t, resp2.Mapping, 1) // Delete operations should not generate tokens assert.Empty(t, msg1.AuthToken) @@ -133,27 +135,35 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } // Register proxies in different clusters (SendServiceUpdate broadcasts to all) ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") ch2 := registerFakeProxy(s, "proxy-b", "cluster-b") - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", Domain: "test.example.com", } + update := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{mapping}, + } + s.SendServiceUpdate(update) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) - require.NotNil(t, msg1) - require.NotNil(t, msg2) + resp1 := drainChannel(ch1) + resp2 := drainChannel(ch2) + require.NotNil(t, resp1) + require.NotNil(t, resp2) + require.Len(t, resp1.Mapping, 1) + require.Len(t, resp2.Mapping, 1) + + msg1 := resp1.Mapping[0] + msg2 := resp2.Mapping[0] assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg2.AuthToken) diff --git a/management/server/account.go b/management/server/account.go index baf091f7c..42aa5d144 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -714,6 +714,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } + err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID) + if err != nil { + return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err) + } + for _, otherUser := range account.Users { if otherUser.Id == userID { continue diff --git a/management/server/account_test.go b/management/server/account_test.go index 287448845..cf096c519 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -31,6 +31,7 @@ import ( reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -3122,7 +3123,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil)) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil)) return manager, updateManager, nil } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 2a154f184..2fdd056d9 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -358,6 +358,10 @@ type testServiceManager struct { store store.Store } +func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { return nil, nil } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index f7a344fcd..f7d07f3a0 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -210,6 +210,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rosenpassEnabled int localUsers int idpUsers int + embeddedIdpTypes map[string]int ) start := time.Now() metricsProperties := make(properties) @@ -218,6 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { rulesProtocol = make(map[string]int) rulesDirection = make(map[string]int) activeUsersLastDay = make(map[string]struct{}) + embeddedIdpTypes = make(map[string]int) uptime = time.Since(w.startupTime).Seconds() connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() @@ -277,6 +279,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties { localUsers++ } else { idpUsers++ + idpType := extractIdpType(idpID) + embeddedIdpTypes[idpType]++ } } } @@ -369,6 +373,11 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["rosenpass_enabled"] = rosenpassEnabled metricsProperties["local_users_count"] = localUsers metricsProperties["idp_users_count"] = idpUsers + metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes) + + for idpType, count := range embeddedIdpTypes { + metricsProperties["embedded_idp_users_"+idpType] = count + } for protocol, count := range rulesProtocol { metricsProperties["rules_protocol_"+protocol] = count @@ -456,6 +465,17 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string) return req, cancel, nil } +// extractIdpType extracts the IdP type from a Dex connector ID. +// Connector IDs are formatted as "-" (e.g., "okta-abc123", "zitadel-xyz"). +// Returns the type prefix, or "oidc" if no known prefix is found. +func extractIdpType(connectorID string) string { + idx := strings.LastIndex(connectorID, "-") + if idx <= 0 { + return "oidc" + } + return strings.ToLower(connectorID[:idx]) +} + func getMinMaxVersion(inputList []string) (string, string) { versions := make([]*version.Version, 0) diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index d0ab45cd7..504d228f7 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -27,7 +27,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { // GetAllAccounts returns a list of *server.Account for use in tests with predefined information func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { localUserID := dex.EncodeDexUserID("10", "local") - idpUserID := dex.EncodeDexUserID("20", "zitadel") + idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0") return []*types.Account{ { Id: "1", @@ -341,4 +341,37 @@ func TestGenerateProperties(t *testing.T) { if properties["idp_users_count"] != 1 { t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"]) } + if properties["embedded_idp_users_zitadel"] != 1 { + t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"]) + } + if properties["embedded_idp_count"] != 1 { + t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"]) + } +} + +func TestExtractIdpType(t *testing.T) { + tests := []struct { + connectorID string + expected string + }{ + {"okta-abc123def", "okta"}, + {"zitadel-d5uv82dra0haedlf6kv0", "zitadel"}, + {"entra-xyz789", "entra"}, + {"google-abc123", "google"}, + {"pocketid-abc123", "pocketid"}, + {"microsoft-abc123", "microsoft"}, + {"authentik-abc123", "authentik"}, + {"keycloak-d5uv82dra0haedlf6kv0", "keycloak"}, + {"local", "oidc"}, + {"", "oidc"}, + } + + for _, tt := range tests { + t.Run(tt.connectorID, func(t *testing.T) { + result := extractIdpType(tt.connectorID) + if result != tt.expected { + t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected) + } + }) + } } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 42bc94921..a382907a7 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1124,6 +1124,21 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID) } +// GetServicesByAccountID mocks base method. +func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*reverseproxy.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByAccountID indicates an expected call of GetServicesByAccountID. +func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID) +} + // GetAccountSettings mocks base method. func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) { m.ctrl.T.Helper() diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go new file mode 100644 index 000000000..1eb25cecc --- /dev/null +++ b/management/server/types/account_components.go @@ -0,0 +1,576 @@ +package types + +import ( + "context" + "slices" + "time" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +func (a *Account) GetPeerNetworkMapFromComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, + groupIDToUserIDs map[string][]string, +) *NetworkMap { + start := time.Now() + + components := a.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + accountZones, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + if components == nil { + return &NetworkMap{Network: a.Network.Copy()} + } + + nm := CalculateNetworkMapFromComponents(ctx, components) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from components, "+ + "peers: %d, offline peers: %d, routes: %d, firewall rules: %d, route firewall rules: %d", + a.Id, objectCount, len(nm.Peers), len(nm.OfflinePeers), len(nm.Routes), len(nm.FirewallRules), len(nm.RoutesFirewallRules)) + } + } + + return nm +} + +func (a *Account) GetPeerNetworkMapComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + groupIDToUserIDs map[string][]string, +) *NetworkMapComponents { + + peer := a.Peers[peerID] + if peer == nil { + return nil + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return nil + } + + components := &NetworkMapComponents{ + PeerID: peerID, + Network: a.Network.Copy(), + NameServerGroups: make([]*nbdns.NameServerGroup, 0), + CustomZoneDomain: peersCustomZone.Domain, + ResourcePoliciesMap: make(map[string][]*Policy), + RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), + NetworkResources: make([]*resourceTypes.NetworkResource, 0), + PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)), + RouterPeers: make(map[string]*nbpeer.Peer), + } + + components.AccountSettings = &AccountSettingsInfo{ + PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: a.Settings.PeerLoginExpiration, + PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: a.Settings.PeerInactivityExpiration, + } + + components.DNSSettings = &a.DNSSettings + + relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers) + + if len(sshReqs.neededGroupIDs) > 0 { + components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs) + } + if sshReqs.needAllowedUserIDs { + components.AllowedUserIDs = a.getAllowedUserIDs() + } + + components.Peers = relevantPeers + components.Groups = relevantGroups + components.Policies = relevantPolicies + components.Routes = relevantRoutes + components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers) + + peerGroups := a.GetPeerGroups(peerID) + components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups) + + for _, nsGroup := range a.NameServerGroups { + if nsGroup.Enabled { + for _, gID := range nsGroup.Groups { + if _, found := relevantGroups[gID]; found { + components.NameServerGroups = append(components.NameServerGroups, nsGroup) + break + } + } + } + } + + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + policies, exists := resourcePolicies[resource.ID] + if !exists { + continue + } + + addSourcePeers := false + + networkRoutingPeers, routerExists := routers[resource.NetworkID] + if routerExists { + if _, ok := networkRoutingPeers[peerID]; ok { + addSourcePeers = true + } + } + + for _, policy := range policies { + if addSourcePeers { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, validatedPeersMap, &components.PostureFailedPeers) { + if _, exists := components.Peers[pID]; !exists { + components.Peers[pID] = a.GetPeer(pID) + } + } + } else { + peerInSources := false + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peerInSources = policy.Rules[0].SourceResource.ID == peerID + } else { + for _, groupID := range policy.SourceGroups() { + if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) { + peerInSources = true + break + } + } + } + if !peerInSources { + continue + } + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID) + if !isValid && len(pname) > 0 { + if _, ok := components.PostureFailedPeers[pname]; !ok { + components.PostureFailedPeers[pname] = make(map[string]struct{}) + } + components.PostureFailedPeers[pname][peer.ID] = struct{}{} + continue + } + addSourcePeers = true + } + + for _, rule := range policy.Rules { + for _, srcGroupID := range rule.Sources { + if g := a.Groups[srcGroupID]; g != nil { + if _, exists := components.Groups[srcGroupID]; !exists { + components.Groups[srcGroupID] = g + } + } + } + for _, dstGroupID := range rule.Destinations { + if g := a.Groups[dstGroupID]; g != nil { + if _, exists := components.Groups[dstGroupID]; !exists { + components.Groups[dstGroupID] = g + } + } + } + } + components.ResourcePoliciesMap[resource.ID] = policies + } + + components.RoutersMap[resource.NetworkID] = networkRoutingPeers + for peerIDKey := range networkRoutingPeers { + if p := a.Peers[peerIDKey]; p != nil { + if _, exists := components.RouterPeers[peerIDKey]; !exists { + components.RouterPeers[peerIDKey] = p + } + if _, exists := components.Peers[peerIDKey]; !exists { + if _, validated := validatedPeersMap[peerIDKey]; validated { + components.Peers[peerIDKey] = p + } + } + } + } + + if addSourcePeers { + components.NetworkResources = append(components.NetworkResources, resource) + } + } + + filterGroupPeers(&components.Groups, components.Peers) + filterPostureFailedPeers(&components.PostureFailedPeers, components.Policies, components.ResourcePoliciesMap, components.Peers) + + return components +} + +type sshRequirements struct { + neededGroupIDs map[string]struct{} + needAllowedUserIDs bool +} + +func (a *Account) getPeersGroupsPoliciesRoutes( + ctx context.Context, + peerID string, + peerSSHEnabled bool, + validatedPeersMap map[string]struct{}, + postureFailedPeers *map[string]map[string]struct{}, +) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) { + relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4) + relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4) + relevantPolicies := make([]*Policy, 0, len(a.Policies)) + relevantRoutes := make([]*route.Route, 0, len(a.Routes)) + sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})} + + relevantPeerIDs[peerID] = a.GetPeer(peerID) + + for groupID, group := range a.Groups { + if slices.Contains(group.Peers, peerID) { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + } + + routeAccessControlGroups := make(map[string]struct{}) + for _, r := range a.Routes { + for _, groupID := range r.Groups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + for _, groupID := range r.PeerGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + if r.Enabled { + for _, groupID := range r.AccessControlGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + routeAccessControlGroups[groupID] = struct{}{} + } + } + relevantRoutes = append(relevantRoutes, r) + } + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + policyRelevant := false + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if len(routeAccessControlGroups) > 0 { + for _, destGroupID := range rule.Destinations { + if _, needed := routeAccessControlGroups[destGroupID]; needed { + policyRelevant = true + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + break + } + } + } + + var sourcePeers, destinationPeers []string + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers = []string{rule.SourceResource.ID} + if rule.SourceResource.ID == peerID { + peerInSources = true + } + } else { + sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers = []string{rule.DestinationResource.ID} + if rule.DestinationResource.ID == peerID { + peerInDestinations = true + } + } else { + destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers) + } + + if peerInSources { + policyRelevant = true + for _, pid := range destinationPeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + } + + if peerInDestinations { + policyRelevant = true + for _, pid := range sourcePeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + + if rule.Protocol == PolicyRuleProtocolNetbirdSSH { + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID := range rule.AuthorizedGroups { + sshReqs.neededGroupIDs[groupID] = struct{}{} + } + case rule.AuthorizedUser != "": + default: + sshReqs.needAllowedUserIDs = true + } + } else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled { + sshReqs.needAllowedUserIDs = true + } + } + } + if policyRelevant { + relevantPolicies = append(relevantPolicies, policy) + } + } + + return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs +} + +func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, + validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { + peerInGroups := false + filteredPeerIDs := make([]string, 0, len(a.Peers)) + seenPeerIds := make(map[string]struct{}, len(groups)) + + for _, gid := range groups { + group := a.GetGroup(gid) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + filteredPeerIDs = filteredPeerIDs[:0] + peerInGroups = false + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + return filteredPeerIDs, peerInGroups + } + + for _, pid := range group.Peers { + if _, seen := seenPeerIds[pid]; seen { + continue + } + seenPeerIds[pid] = struct{}{} + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + } + + return filteredPeerIDs, peerInGroups +} + +func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) { + peer, ok := a.Peers[peerID] + if !ok || peer == nil { + return false, "" + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, _ := check.Check(ctx, *peer) + if !isValid { + return false, postureChecksID + } + } + } + return true, "" +} + +func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) []string { + var dest []string + for _, peerID := range inputPeers { + if _, validated := validatedPeersMap[peerID]; !validated { + continue + } + valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID) + if valid { + dest = append(dest, peerID) + continue + } + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peerID] = struct{}{} + } + return dest +} + +func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) { + for groupID, groupInfo := range *groups { + filteredPeers := make([]string, 0, len(groupInfo.Peers)) + for _, pid := range groupInfo.Peers { + if _, exists := peers[pid]; exists { + filteredPeers = append(filteredPeers, pid) + } + } + + if len(filteredPeers) == 0 { + delete(*groups, groupID) + } else if len(filteredPeers) != len(groupInfo.Peers) { + ng := groupInfo.Copy() + ng.Peers = filteredPeers + (*groups)[groupID] = ng + } + } +} + +func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{}, policies []*Policy, resourcePoliciesMap map[string][]*Policy, peers map[string]*nbpeer.Peer) { + if len(*postureFailedPeers) == 0 { + return + } + + referencedPostureChecks := make(map[string]struct{}) + for _, policy := range policies { + for _, checkID := range policy.SourcePostureChecks { + referencedPostureChecks[checkID] = struct{}{} + } + } + for _, resPolicies := range resourcePoliciesMap { + for _, policy := range resPolicies { + for _, checkID := range policy.SourcePostureChecks { + referencedPostureChecks[checkID] = struct{}{} + } + } + } + + for checkID, failedPeers := range *postureFailedPeers { + if _, referenced := referencedPostureChecks[checkID]; !referenced { + delete(*postureFailedPeers, checkID) + continue + } + for peerID := range failedPeers { + if _, exists := peers[peerID]; !exists { + delete(failedPeers, peerID) + } + } + if len(failedPeers) == 0 { + delete(*postureFailedPeers, checkID) + } + } +} + +func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord { + if len(records) == 0 || len(peers) == 0 { + return nil + } + + peerIPs := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + if peer != nil { + peerIPs[peer.IP.String()] = struct{}{} + } + } + + filteredRecords := make([]nbdns.SimpleRecord, 0, len(records)) + for _, record := range records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} + +func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string { + if len(neededGroupIDs) == 0 { + return nil + } + + filtered := make(map[string][]string, len(neededGroupIDs)) + for groupID := range neededGroupIDs { + if users, ok := fullMap[groupID]; ok { + filtered[groupID] = users + } + } + return filtered +} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go new file mode 100644 index 000000000..c5844cca0 --- /dev/null +++ b/management/server/types/networkmap_comparison_test.go @@ -0,0 +1,592 @@ +package types + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" +) + +func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + if components == nil { + t.Fatal("GetPeerNetworkMapComponents returned nil") + } + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + + if newNetworkMap == nil { + t.Fatal("CalculateNetworkMapFromComponents returned nil") + } + + compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) +} + +func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + goldenDir := filepath.Join("testdata", "comparison") + err = os.MkdirAll(goldenDir, 0755) + require.NoError(t, err) + + legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") + err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) + require.NoError(t, err, "error writing legacy golden file") + + newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") + err = os.WriteFile(newGoldenPath, newJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + componentsPath := filepath.Join(goldenDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s", + legacyGoldenPath, newGoldenPath) + + t.Logf("✅ NetworkMaps are identical") + t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) + t.Logf(" Components NetworkMap: %s", newGoldenPath) +} + +func normalizeAndSortNetworkMap(nm *NetworkMap) { + if nm == nil { + return + } + + sort.Slice(nm.Peers, func(i, j int) bool { + return nm.Peers[i].ID < nm.Peers[j].ID + }) + + sort.Slice(nm.OfflinePeers, func(i, j int) bool { + return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID + }) + + sort.Slice(nm.Routes, func(i, j int) bool { + return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) + }) + + sort.Slice(nm.FirewallRules, func(i, j int) bool { + if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { + return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP + } + if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { + return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction + } + if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { + return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol + } + if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { + return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port + } + return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID + }) + + for i := range nm.RoutesFirewallRules { + sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) + } + + sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { + if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { + return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination + } + + minLen := len(nm.RoutesFirewallRules[i].SourceRanges) + if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { + minLen = len(nm.RoutesFirewallRules[j].SourceRanges) + } + for k := 0; k < minLen; k++ { + if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { + return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] + } + } + if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { + return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) + } + + if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { + return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) + } + + if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { + return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID + } + + if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { + return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port + } + + return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol + }) + + if nm.DNSConfig.CustomZones != nil { + for i := range nm.DNSConfig.CustomZones { + sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { + return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name + }) + } + } + + if len(nm.DNSConfig.NameServerGroups) != 0 { + sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { + return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name + }) + } +} + +func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) { + t.Helper() + + if legacy.Network.Serial != current.Network.Serial { + t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial) + } + + if len(legacy.Peers) != len(current.Peers) { + t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers)) + } + + legacyPeerIDs := make(map[string]bool) + for _, p := range legacy.Peers { + legacyPeerIDs[p.ID] = true + } + + for _, p := range current.Peers { + if !legacyPeerIDs[p.ID] { + t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID) + } + } + + if len(legacy.OfflinePeers) != len(current.OfflinePeers) { + t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers)) + } + + if len(legacy.FirewallRules) != len(current.FirewallRules) { + t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules)) + } + + if len(legacy.Routes) != len(current.Routes) { + t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes)) + } + + if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) { + t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules)) + } + + if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable { + t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable) + } +} + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" + expiredPeerID = "peer-98" + offlinePeerID = "peer-99" + routingPeerID = "peer-95" + testAccountID = "account-comparison-test" +) + +func createTestAccount() *Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + } + + groups := map[string]*Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, + Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &Network{ + Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} + +func BenchmarkLegacyNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + } +} + +func BenchmarkComponentsNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func BenchmarkComponentsCreation(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + } +} + +func BenchmarkCalculationFromComponents(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go new file mode 100644 index 000000000..ab6b006e6 --- /dev/null +++ b/management/server/types/networkmap_components.go @@ -0,0 +1,938 @@ +package types + +import ( + "context" + "maps" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/netbirdio/netbird/client/ssh/auth" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" + +type NetworkMapComponents struct { + PeerID string + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + Peers map[string]*nbpeer.Peer + Groups map[string]*Group + Policies []*Policy + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + AccountZones []nbdns.CustomZone + ResourcePoliciesMap map[string][]*Policy + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} + PostureFailedPeers map[string]map[string]struct{} + + RouterPeers map[string]*nbpeer.Peer +} + +type AccountSettingsInfo struct { + PeerLoginExpirationEnabled bool + PeerLoginExpiration time.Duration + PeerInactivityExpirationEnabled bool + PeerInactivityExpiration time.Duration +} + +func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer { + return c.Peers[peerID] +} + +func (c *NetworkMapComponents) GetRouterPeerInfo(peerID string) *nbpeer.Peer { + return c.RouterPeers[peerID] +} + +func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group { + return c.Groups[groupID] +} + +func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool { + group := c.GetGroupInfo(groupID) + if group == nil { + return false + } + + return slices.Contains(group.Peers, peerID) +} + +func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} { + groups := make(map[string]struct{}) + for groupID, group := range c.Groups { + if slices.Contains(group.Peers, peerID) { + groups[groupID] = struct{}{} + } + } + return groups +} + +func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool { + _, exists := c.Peers[peerID] + if !exists { + return false + } + if len(postureCheckIDs) == 0 { + return true + } + for _, checkID := range postureCheckIDs { + if failedPeers, exists := c.PostureFailedPeers[checkID]; exists { + if _, failed := failedPeers[peerID]; failed { + return false + } + } + } + return true +} + +func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap { + return components.Calculate(ctx) +} + +func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap { + targetPeerID := c.PeerID + + peerGroups := c.GetPeerGroups(targetPeerID) + + aclPeers, firewallRules, authorizedUsers, sshEnabled := c.getPeerConnectionResources(targetPeerID) + + peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers) + + routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups) + routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID) + + isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + } + + peersToConnectIncludingRouters := c.addNetworksRoutingPeers( + networkResourcesRoutes, + targetPeerID, + peersToConnect, + expiredPeers, + isRouter, + sourcePeers, + ) + + dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var customZones []nbdns.CustomZone + + if c.CustomZoneDomain != "" && len(c.AllDNSRecords) > 0 { + customZones = append(customZones, nbdns.CustomZone{ + Domain: c.CustomZoneDomain, + Records: c.AllDNSRecords, + }) + } + + customZones = append(customZones, c.AccountZones...) + + dnsUpdate.CustomZones = customZones + dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID) + } + + return &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: c.Network.Copy(), + Routes: append(networkResourcesRoutes, routesUpdate...), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...), + AuthorizedUsers: authorizedUsers, + EnableSSH: sshEnabled, + } +} + +func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { + targetPeer := c.GetPeerInfo(targetPeerID) + if targetPeer == nil { + return nil, nil, nil, false + } + + generateResources, getAccumulatedResources := c.connResourcesGenerator(targetPeer) + authorizedUsers := make(map[string]map[string]struct{}) + sshEnabled := false + + for _, policy := range c.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = c.getPeerFromResource(rule.SourceResource, targetPeerID) + } else { + sourcePeers, peerInSources = c.getAllPeersFromGroups(rule.Sources, targetPeerID, policy.SourcePostureChecks) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = c.getPeerFromResource(rule.DestinationResource, targetPeerID) + } else { + destinationPeers, peerInDestinations = c.getAllPeersFromGroups(rule.Destinations, targetPeerID, nil) + } + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := c.GroupIDToUserIDs[groupID] + if !ok { + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs() + } + } + } + + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (c *NetworkMapComponents) getAllowedUserIDs() map[string]struct{} { + if c.AllowedUserIDs != nil { + result := make(map[string]struct{}, len(c.AllowedUserIDs)) + maps.Copy(result, c.AllowedUserIDs) + return result + } + return make(map[string]struct{}) +} + +func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + protocol := rule.Protocol + if protocol == PolicyRuleProtocolNetbirdSSH { + protocol = PolicyRuleProtocolTCP + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: net.IP(peer.IP).String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(protocol), + } + + ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + rules = append(rules, &fr) + continue + } + + rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ + ID: rule.ID, + Ports: rule.Ports, + PortRanges: rule.PortRanges, + Protocol: rule.Protocol, + Action: rule.Action, + }, targetPeer)...) + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +func (c *NetworkMapComponents) getAllPeersFromGroups(groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { + peerInGroups := false + uniquePeerIDs := c.getUniquePeerIDsFromGroupsIDs(groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + + for _, p := range uniquePeerIDs { + peerInfo := c.GetPeerInfo(p) + if peerInfo == nil { + continue + } + + if _, ok := c.Peers[p]; !ok { + continue + } + + if !c.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) { + continue + } + + if p == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peerInfo) + } + + return filteredPeers, peerInGroups +} + +func (c *NetworkMapComponents) getUniquePeerIDsFromGroupsIDs(groups []string) []string { + peerIDs := make(map[string]struct{}, len(groups)) + for _, groupID := range groups { + group := c.GetGroupInfo(groupID) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } + } + + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) + } + + return ids +} + +func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + if resource.ID == peerID { + return []*nbpeer.Peer{}, true + } + + peerInfo := c.GetPeerInfo(resource.ID) + if peerInfo == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peerInfo}, false +} + +func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, p := range aclPeers { + expired, _ := p.LoginExpired(c.AccountSettings.PeerLoginExpiration) + if c.AccountSettings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + return peersToConnect, expiredPeers +} + +func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := c.GetPeerGroups(peerID) + enabled := true + for _, groupID := range c.DNSSettings.DisabledManagementGroups { + if _, found := peerGroups[groupID]; found { + enabled = false + break + } + } + return enabled +} + +func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { + groupList := c.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range c.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + targetPeerInfo := c.GetPeerInfo(peerID) + if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peerInfo.IP.String() == ns.IP.String() { + return true + } + } + return false +} + +func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { + routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + for _, peer := range aclPeers { + activeRoutes, _ := c.getRoutingPeerRoutes(peer.ID) + groupFilteredRoutes := c.filterRoutesByGroups(activeRoutes, peerGroups) + filteredRoutes := c.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peerInfo := c.GetPeerInfo(peerID) + if peerInfo == nil { + peerInfo = c.GetRouterPeerInfo(peerID) + } + if peerInfo == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + routeObj := c.copyRoute(r) + routeObj.Peer = peerInfo.Key + + if r.Enabled { + enabledRoutes = append(enabledRoutes, routeObj) + return + } + disabledRoutes = append(disabledRoutes, routeObj) + } + + for _, r := range c.Routes { + for _, groupID := range r.PeerGroups { + group := c.GetGroupInfo(groupID) + if group == nil { + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := c.copyRoute(r) + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) + takeRoute(newPeerRoute) + break + } + } + if r.Peer == peerID { + takeRoute(c.copyRoute(r)) + } + } + + return enabledRoutes, disabledRoutes +} + +func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route { + var groups, accessControlGroups, peerGroups []string + var domains domain.List + + if r.Groups != nil { + groups = append([]string{}, r.Groups...) + } + if r.AccessControlGroups != nil { + accessControlGroups = append([]string{}, r.AccessControlGroups...) + } + if r.PeerGroups != nil { + peerGroups = append([]string{}, r.PeerGroups...) + } + if r.Domains != nil { + domains = append(domain.List{}, r.Domains...) + } + + return &route.Route{ + ID: r.ID, + AccountID: r.AccountID, + Network: r.Network, + NetworkType: r.NetworkType, + Description: r.Description, + Peer: r.Peer, + PeerID: r.PeerID, + Metric: r.Metric, + Masquerade: r.Masquerade, + NetID: r.NetID, + Enabled: r.Enabled, + Groups: groups, + AccessControlGroups: accessControlGroups, + PeerGroups: peerGroups, + Domains: domains, + KeepRoute: r.KeepRoute, + SkipAutoApply: r.SkipAutoApply, + } +} + +func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := c.getRoutingPeerRoutes(peerID) + for _, r := range enabledRoutes { + if len(r.AccessControlGroups) == 0 { + defaultPermit := c.getDefaultPermit(r) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := c.getDistributionGroupsPeers(r) + + for _, accessGroup := range r.AccessControlGroups { + policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup}) + rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if r.Network.Addr().Is6() { + sources = []string{"::/0"} + } + + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: r.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: r.Domains, + IsDynamic: r.IsDynamic(), + RouteID: r.ID, + } + + rules = append(rules, &rule) + + if r.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +func (c *NetworkMapComponents) getDistributionGroupsPeers(r *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range r.Groups { + group := c.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + for _, policy := range c.Policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies = append(routePolicies, policy) + } + } + } + } + + return routePolicies +} + +func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (c *NetworkMapComponents) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := c.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := c.Peers[pID] + if distPeer && valid && c.ValidatePostureChecksOnPeer(pID, postureChecks) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := c.Peers[rule.SourceResource.ID] + if distPeer && valid && c.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peerInfo := c.GetPeerInfo(pID) + if peerInfo == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peerInfo) + } + return distributionGroupPeers +} + +func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (bool, []*route.Route, map[string]struct{}) { + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + for _, resource := range c.NetworkResources { + if !resource.Enabled { + continue + } + + var addSourcePeers bool + + networkRoutingPeers, exists := c.RoutersMap[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerID, router)...) + } + } + + addedResourceRoute := false + for _, policy := range c.ResourcePoliciesMap[resource.ID] { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups()) + } + if addSourcePeers { + for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + } else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...) + } + addedResourceRoute = true + } + if addedResourceRoute { + break + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { + resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID] + + var routes []*route.Route + if len(resourceAppliedPolicies) > 0 { + peerInfo := c.GetPeerInfo(peerID) + if peerInfo != nil { + routes = append(routes, c.networkResourceToRoute(resource, peerInfo, router)) + } + } + + return routes +} + +func (c *NetworkMapComponents) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(resource.ID + ":" + peer.ID), + AccountID: resource.AccountID, + Peer: peer.Key, + PeerID: peer.ID, + Metric: router.Metric, + Masquerade: router.Masquerade, + Enabled: resource.Enabled, + KeepRoute: true, + NetID: route.NetID(resource.Name), + Description: resource.Description, + } + + if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet { + r.Network = resource.Prefix + + r.NetworkType = route.IPv4Network + if resource.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if resource.Type == resourceTypes.Domain { + domainList, err := domain.FromStringList([]string{resource.Domain}) + if err == nil { + r.Domains = domainList + r.NetworkType = route.DomainNetwork + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + } + + return r +} + +func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if c.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) { + dest = append(dest, peerID) + } + } + return dest +} + +func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + peerInfo := c.GetPeerInfo(peerID) + if peerInfo == nil { + return routesFirewallRules + } + + for _, r := range routes { + if r.Peer != peerInfo.Key { + continue + } + + resourceID := string(r.GetResourceID()) + resourcePolicies := c.ResourcePoliciesMap[resourceID] + distributionPeers := c.getPoliciesSourcePeers(resourcePolicies) + + rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + } + + return routesFirewallRules +} + +func (c *NetworkMapComponents) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := c.GetGroupInfo(sourceGroup) + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } + } + } + + return sourcePeers +} + +func (c *NetworkMapComponents) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peerID string, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peerID) + delete(networkRoutesPeers, peerID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + peerInfo := c.GetPeerInfo(p) + if peerInfo == nil { + peerInfo = c.GetRouterPeerInfo(p) + } + if peerInfo != nil { + peersToConnect = append(peersToConnect, peerInfo) + } + } + + return peersToConnect +} diff --git a/management/server/types/networkmap_components_compact.go b/management/server/types/networkmap_components_compact.go new file mode 100644 index 000000000..b60f8bdb1 --- /dev/null +++ b/management/server/types/networkmap_components_compact.go @@ -0,0 +1,230 @@ +package types + +import ( + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" +) + +type GroupCompact struct { + Name string + PeerIndexes []int +} + +type NetworkMapComponentsCompact struct { + PeerID string + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + AllPeers []*nbpeer.Peer + PeerIndexes []int + RouterPeerIndexes []int + + Groups map[string]*GroupCompact + AllPolicies []*Policy + PolicyIndexes []int + ResourcePoliciesMap map[string][]int + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + AccountZones []nbdns.CustomZone + + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} + PostureFailedPeers map[string]map[string]struct{} +} + +func (c *NetworkMapComponents) ToCompact() *NetworkMapComponentsCompact { + peerToIndex := make(map[string]int) + var allPeers []*nbpeer.Peer + + for id, peer := range c.Peers { + if _, exists := peerToIndex[id]; !exists { + peerToIndex[id] = len(allPeers) + allPeers = append(allPeers, peer) + } + } + + for id, peer := range c.RouterPeers { + if _, exists := peerToIndex[id]; !exists { + peerToIndex[id] = len(allPeers) + allPeers = append(allPeers, peer) + } + } + + peerIndexes := make([]int, 0, len(c.Peers)) + for id := range c.Peers { + peerIndexes = append(peerIndexes, peerToIndex[id]) + } + + routerPeerIndexes := make([]int, 0, len(c.RouterPeers)) + for id := range c.RouterPeers { + routerPeerIndexes = append(routerPeerIndexes, peerToIndex[id]) + } + + groups := make(map[string]*GroupCompact, len(c.Groups)) + for id, group := range c.Groups { + peerIdxs := make([]int, 0, len(group.Peers)) + for _, peerID := range group.Peers { + if idx, ok := peerToIndex[peerID]; ok { + peerIdxs = append(peerIdxs, idx) + } + } + groups[id] = &GroupCompact{ + Name: group.Name, + PeerIndexes: peerIdxs, + } + } + + policyToIndex := make(map[*Policy]int) + var allPolicies []*Policy + + for _, policy := range c.Policies { + if _, exists := policyToIndex[policy]; !exists { + policyToIndex[policy] = len(allPolicies) + allPolicies = append(allPolicies, policy) + } + } + + for _, policies := range c.ResourcePoliciesMap { + for _, policy := range policies { + if _, exists := policyToIndex[policy]; !exists { + policyToIndex[policy] = len(allPolicies) + allPolicies = append(allPolicies, policy) + } + } + } + + policyIndexes := make([]int, len(c.Policies)) + for i, policy := range c.Policies { + policyIndexes[i] = policyToIndex[policy] + } + + var resourcePoliciesMap map[string][]int + if len(c.ResourcePoliciesMap) > 0 { + resourcePoliciesMap = make(map[string][]int, len(c.ResourcePoliciesMap)) + for resID, policies := range c.ResourcePoliciesMap { + indexes := make([]int, len(policies)) + for i, policy := range policies { + indexes[i] = policyToIndex[policy] + } + resourcePoliciesMap[resID] = indexes + } + } + + return &NetworkMapComponentsCompact{ + PeerID: c.PeerID, + Network: c.Network, + AccountSettings: c.AccountSettings, + DNSSettings: c.DNSSettings, + CustomZoneDomain: c.CustomZoneDomain, + + AllPeers: allPeers, + PeerIndexes: peerIndexes, + RouterPeerIndexes: routerPeerIndexes, + + Groups: groups, + AllPolicies: allPolicies, + PolicyIndexes: policyIndexes, + ResourcePoliciesMap: resourcePoliciesMap, + Routes: c.Routes, + NameServerGroups: c.NameServerGroups, + AllDNSRecords: c.AllDNSRecords, + AccountZones: c.AccountZones, + + RoutersMap: c.RoutersMap, + NetworkResources: c.NetworkResources, + + GroupIDToUserIDs: c.GroupIDToUserIDs, + AllowedUserIDs: c.AllowedUserIDs, + PostureFailedPeers: c.PostureFailedPeers, + } +} + +func (c *NetworkMapComponentsCompact) ToFull() *NetworkMapComponents { + peers := make(map[string]*nbpeer.Peer, len(c.PeerIndexes)) + for _, idx := range c.PeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peer := c.AllPeers[idx] + peers[peer.ID] = peer + } + } + + routerPeers := make(map[string]*nbpeer.Peer, len(c.RouterPeerIndexes)) + for _, idx := range c.RouterPeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peer := c.AllPeers[idx] + routerPeers[peer.ID] = peer + } + } + + groups := make(map[string]*Group, len(c.Groups)) + for id, gc := range c.Groups { + peerIDs := make([]string, 0, len(gc.PeerIndexes)) + for _, idx := range gc.PeerIndexes { + if idx >= 0 && idx < len(c.AllPeers) { + peerIDs = append(peerIDs, c.AllPeers[idx].ID) + } + } + groups[id] = &Group{ + ID: id, + Name: gc.Name, + Peers: peerIDs, + } + } + + policies := make([]*Policy, len(c.PolicyIndexes)) + for i, idx := range c.PolicyIndexes { + if idx >= 0 && idx < len(c.AllPolicies) { + policies[i] = c.AllPolicies[idx] + } + } + + var resourcePoliciesMap map[string][]*Policy + if len(c.ResourcePoliciesMap) > 0 { + resourcePoliciesMap = make(map[string][]*Policy, len(c.ResourcePoliciesMap)) + for resID, indexes := range c.ResourcePoliciesMap { + pols := make([]*Policy, 0, len(indexes)) + for _, idx := range indexes { + if idx >= 0 && idx < len(c.AllPolicies) { + pols = append(pols, c.AllPolicies[idx]) + } + } + resourcePoliciesMap[resID] = pols + } + } + + return &NetworkMapComponents{ + PeerID: c.PeerID, + Network: c.Network, + AccountSettings: c.AccountSettings, + DNSSettings: c.DNSSettings, + CustomZoneDomain: c.CustomZoneDomain, + + Peers: peers, + RouterPeers: routerPeers, + + Groups: groups, + Policies: policies, + Routes: c.Routes, + NameServerGroups: c.NameServerGroups, + AllDNSRecords: c.AllDNSRecords, + AccountZones: c.AccountZones, + + ResourcePoliciesMap: resourcePoliciesMap, + RoutersMap: c.RoutersMap, + NetworkResources: c.NetworkResources, + + GroupIDToUserIDs: c.GroupIDToUserIDs, + AllowedUserIDs: c.AllowedUserIDs, + PostureFailedPeers: c.PostureFailedPeers, + } +} diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 121621109..c594f9800 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -53,6 +53,7 @@ var ( certLockMethod string wgPort int proxyProtocol bool + preSharedKey string ) var rootCmd = &cobra.Command{ @@ -84,6 +85,7 @@ func init() { rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") + rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") } // Execute runs the root command. @@ -156,6 +158,7 @@ func runServer(cmd *cobra.Command, args []string) error { CertLockMethod: nbacme.CertLockMethod(certLockMethod), WireguardPort: wgPort, ProxyProtocol: proxyProtocol, + PreSharedKey: preSharedKey, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index d7fd2746f..481b42d2b 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -86,6 +86,13 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo } } +// ClientConfig holds configuration for the embedded NetBird client. +type ClientConfig struct { + MgmtAddr string + WGPort int + PreSharedKey string +} + type statusNotifier interface { NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error } @@ -98,10 +105,9 @@ type managementClient interface { // backed by underlying NetBird connections. // Clients are keyed by AccountID, allowing multiple domains to share the same connection. type NetBird struct { - mgmtAddr string proxyID string proxyAddr string - wgPort int + clientCfg ClientConfig logger *log.Logger mgmtClient managementClient transportCfg transportConfig @@ -229,11 +235,12 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // The peer has already been created via CreateProxyPeer RPC with the public key. client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, - ManagementURL: n.mgmtAddr, + ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.wgPort, + WireguardPort: &n.clientCfg.WGPort, + PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) @@ -536,18 +543,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client { return result } -// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random +// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random // OS-assigned port. A fixed port only works with single-account deployments; // multiple accounts will fail to bind the same port. -func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { +func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { if logger == nil { logger = log.StandardLogger() } return &NetBird{ - mgmtAddr: mgmtAddr, proxyID: proxyID, proxyAddr: proxyAddr, - wgPort: wgPort, + clientCfg: clientCfg, logger: logger, clients: make(map[types.AccountID]*clientEntry), statusNotifier: notifier, diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 3e76af9da..0a742c2fa 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -49,7 +49,11 @@ func (m *mockStatusNotifier) calls() []statusCall { // mockNetBird creates a NetBird instance for testing without actually connecting. // It uses an invalid management URL to prevent real connections. func mockNetBird() *NetBird { - return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{}) + return NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, nil, &mockMgmtClient{}) } func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { @@ -282,7 +286,11 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) { func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") // Add first domain — creates a new client entry. @@ -308,7 +316,11 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 849a41ace..68f0bfdc8 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -192,6 +192,10 @@ type storeBackedServiceManager struct { tokenStore *nbgrpc.OneTimeTokenStore } +func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } diff --git a/proxy/server.go b/proxy/server.go index 60811e53b..48a876899 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -114,6 +114,8 @@ type Server struct { // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool + // PreSharedKey used for tunnel between proxy and peers (set globally not per account) + PreSharedKey string } // NotifyStatus sends a status update to management about tunnel connectivity @@ -163,7 +165,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Initialize the netbird client, this is required to build peer connections // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) + s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ + MgmtAddr: s.ManagementAddress, + WGPort: s.WireguardPort, + PreSharedKey: s.PreSharedKey, + }, s.Logger, s, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) if err != nil {