From b6038e8acda35545ff551847458bd191f2f008c3 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:13:22 +0200 Subject: [PATCH 01/11] [management] refactor: changeable pat rate limiting (#5946) --- management/internals/server/boot.go | 12 +- management/server/http/handler.go | 44 +---- .../server/http/middleware/auth_middleware.go | 13 +- .../http/middleware/auth_middleware_test.go | 22 ++- .../server/http/middleware/rate_limiter.go | 97 ++++++++++ .../http/middleware/rate_limiter_test.go | 171 ++++++++++++++++++ .../testing/testing_tools/channel/channel.go | 4 +- 7 files changed, 304 insertions(+), 59 deletions(-) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 24dfb641b..2b40c0aad 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -30,6 +30,7 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" @@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { + return Create(s, func() *middleware.APIRateLimiter { + cfg, enabled := middleware.RateLimiterConfigFromEnv() + limiter := middleware.NewAPIRateLimiter(cfg) + limiter.SetEnabled(enabled) + return limiter + }) +} + func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { trustedPeers := s.Config.ReverseProxy.TrustedPeers diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ad36b9d46..56b2d8203 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -5,9 +5,6 @@ import ( "fmt" "net/http" "net/netip" - "os" - "strconv" - "time" "github.com/gorilla/mux" "github.com/rs/cors" @@ -66,14 +63,11 @@ import ( ) const ( - apiPrefix = "/api" - rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" - rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" - rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" + apiPrefix = "/api" ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("failed to add bypass path: %w", err) } - var rateLimitingConfig *middleware.RateLimiterConfig - if os.Getenv(rateLimitingEnabledKey) == "true" { - rpm := 6 - if v := os.Getenv(rateLimitingRPMKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) - } else { - rpm = value - } - } - - burst := 500 - if v := os.Getenv(rateLimitingBurstKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) - } else { - burst = value - } - } - - rateLimitingConfig = &middleware.RateLimiterConfig{ - RequestsPerMinute: float64(rpm), - Burst: burst, - CleanupInterval: 6 * time.Hour, - LimiterTTL: 24 * time.Hour, - } + if rateLimiter == nil { + log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled") + rateLimiter = middleware.NewAPIRateLimiter(nil) + rateLimiter.SetEnabled(false) } authMiddleware := middleware.NewAuthMiddleware( @@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, - rateLimitingConfig, + rateLimiter, appMetrics.GetMeter(), ) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 8106380f2..6d075d9c2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -43,14 +43,9 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, - rateLimiterConfig *RateLimiterConfig, + rateLimiter *APIRateLimiter, meter metric.Meter, ) *AuthMiddleware { - var rateLimiter *APIRateLimiter - if rateLimiterConfig != nil { - rateLimiter = NewAPIRateLimiter(rateLimiterConfig) - } - var patUsageTracker *PATUsageTracker if meter != nil { var err error @@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] m.patUsageTracker.IncrementUsage(token) } - if m.rateLimiter != nil && !isTerraformRequest(r) { - if !m.rateLimiter.Allow(token) { - return status.Errorf(status.TooManyRequests, "too many requests") - } + if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) { + return status.Errorf(status.TooManyRequests, "too many requests") } ctx := r.Context() diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index f397c63a4..8f736fbfd 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) @@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index 936b34319..bfd44afee 100644 --- a/management/server/http/middleware/rate_limiter.go +++ b/management/server/http/middleware/rate_limiter.go @@ -4,14 +4,27 @@ import ( "context" "net" "net/http" + "os" + "strconv" "sync" + "sync/atomic" "time" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" "github.com/netbirdio/netbird/shared/management/http/util" ) +const ( + RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED" + RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST" + RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM" + + defaultAPIRPM = 6 + defaultAPIBurst = 500 +) + // RateLimiterConfig holds configuration for the API rate limiter type RateLimiterConfig struct { // RequestsPerMinute defines the rate at which tokens are replenished @@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig { } } +func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) { + rpm := defaultAPIRPM + if v := os.Getenv(RateLimitingRPMEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm) + } else { + rpm = value + } + } + if rpm <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM) + rpm = defaultAPIRPM + } + + burst := defaultAPIBurst + if v := os.Getenv(RateLimitingBurstEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst) + } else { + burst = value + } + } + if burst <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst) + burst = defaultAPIBurst + } + + return &RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + }, os.Getenv(RateLimitingEnabledEnv) == "true" +} + // limiterEntry holds a rate limiter and its last access time type limiterEntry struct { limiter *rate.Limiter @@ -46,6 +96,7 @@ type APIRateLimiter struct { limiters map[string]*limiterEntry mu sync.RWMutex stopChan chan struct{} + enabled atomic.Bool } // NewAPIRateLimiter creates a new API rate limiter with the given configuration @@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { limiters: make(map[string]*limiterEntry), stopChan: make(chan struct{}), } + rl.enabled.Store(true) go rl.cleanupLoop() return rl } +func (rl *APIRateLimiter) SetEnabled(enabled bool) { + rl.enabled.Store(enabled) +} + +func (rl *APIRateLimiter) Enabled() bool { + return rl.enabled.Load() +} + +func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) { + if config == nil { + return + } + if config.RequestsPerMinute <= 0 || config.Burst <= 0 { + log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst) + return + } + + newRPS := rate.Limit(config.RequestsPerMinute / 60.0) + newBurst := config.Burst + + rl.mu.Lock() + rl.config.RequestsPerMinute = config.RequestsPerMinute + rl.config.Burst = newBurst + snapshot := make([]*rate.Limiter, 0, len(rl.limiters)) + for _, entry := range rl.limiters { + snapshot = append(snapshot, entry.limiter) + } + rl.mu.Unlock() + + for _, l := range snapshot { + l.SetLimit(newRPS) + l.SetBurst(newBurst) + } +} + // Allow checks if a request for the given key (token) is allowed func (rl *APIRateLimiter) Allow(key string) bool { + if !rl.enabled.Load() { + return true + } limiter := rl.getLimiter(key) return limiter.Allow() } @@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool { // Wait blocks until the rate limiter allows another request for the given key // Returns an error if the context is canceled func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + if !rl.enabled.Load() { + return nil + } limiter := rl.getLimiter(key) return limiter.Wait(ctx) } @@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) { // Returns 429 Too Many Requests if the rate limit is exceeded. func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !rl.enabled.Load() { + next.ServeHTTP(w, r) + return + } clientIP := getClientIP(r) if !rl.Allow(clientIP) { util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go index 68f804e57..4b97d1874 100644 --- a/management/server/http/middleware/rate_limiter_test.go +++ b/management/server/http/middleware/rate_limiter_test.go @@ -1,8 +1,10 @@ package middleware import ( + "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) { // Should be allowed again assert.True(t, rl.Allow("test-key")) } + +func TestAPIRateLimiter_SetEnabled(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("key")) + assert.False(t, rl.Allow("key"), "burst exhausted while enabled") + + rl.SetEnabled(false) + assert.False(t, rl.Enabled()) + for i := 0; i < 5; i++ { + assert.True(t, rl.Allow("key"), "disabled limiter must always allow") + } + + rl.SetEnabled(true) + assert.True(t, rl.Enabled()) + assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state") +} + +func TestAPIRateLimiter_UpdateConfig(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k1")) + assert.True(t, rl.Allow("k1")) + assert.False(t, rl.Allow("k1"), "burst=2 exhausted") + + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + + // New burst applies to existing keys in place; bucket refills up to new burst over time, + // but importantly newly-added keys use the updated config immediately. + assert.True(t, rl.Allow("k2")) + for i := 0; i < 9; i++ { + assert.True(t, rl.Allow("k2")) + } + assert.False(t, rl.Allow("k2"), "new burst=10 exhausted") +} + +func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + rl.UpdateConfig(nil) // must not panic or zero the config + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) +} + +func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) + + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + + rl.Reset("k") + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored") +} + +func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 600, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + var wg sync.WaitGroup + stop := make(chan struct{}) + + for i := 0; i < 8; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("k%d", id) + for { + select { + case <-stop: + return + default: + rl.Allow(key) + } + } + }(i) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 200; i++ { + select { + case <-stop: + return + default: + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: float64(30 + (i % 90)), + Burst: 1 + (i % 20), + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + rl.SetEnabled(i%2 == 0) + } + } + }() + + time.Sleep(100 * time.Millisecond) + close(stop) + wg.Wait() +} + +func TestRateLimiterConfigFromEnv(t *testing.T) { + t.Setenv(RateLimitingEnabledEnv, "true") + t.Setenv(RateLimitingRPMEnv, "42") + t.Setenv(RateLimitingBurstEnv, "7") + + cfg, enabled := RateLimiterConfigFromEnv() + assert.True(t, enabled) + assert.Equal(t, float64(42), cfg.RequestsPerMinute) + assert.Equal(t, 7, cfg.Burst) + + t.Setenv(RateLimitingEnabledEnv, "false") + _, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + + t.Setenv(RateLimitingEnabledEnv, "") + t.Setenv(RateLimitingRPMEnv, "") + t.Setenv(RateLimitingBurstEnv, "") + cfg, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute) + assert.Equal(t, defaultAPIBurst, cfg.Burst) + + t.Setenv(RateLimitingRPMEnv, "0") + t.Setenv(RateLimitingBurstEnv, "-5") + cfg, _ = RateLimiterConfigFromEnv() + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default") + assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default") +} diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 0203d6177..1a8b83c7e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } From fa0d58d093a883ddd6183c52e1b92e0888b2bafd Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:01:54 +0200 Subject: [PATCH 02/11] [management] exclude peers for expiration job that have already been marked expired (#5970) --- management/server/account_test.go | 23 +++++++++++ management/server/management_proto_test.go | 4 +- management/server/peer.go | 4 ++ management/server/store/sql_store.go | 2 +- management/server/store/sql_store_test.go | 40 ++++++++++++++----- .../testdata/store_with_expired_peers.sql | 1 + 6 files changed, 62 insertions(+), 12 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 4453d064e..bcc73d52f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) { + ctx := context.Background() + + testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanUp) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + // Verify the already-expired peer is excluded at the store level + peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query") + assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired") + } + + // Only the non-expired peer with expiration enabled should be returned + require.Len(t, peers, 1) + assert.Equal(t, "notexpired01", peers[0].ID) +} + func TestAccount_GetInactivePeers(t *testing.T) { type test struct { name string diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 4e6eb0a33..18d85315d 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) { } // expired peers come separately. - if len(networkMap.GetOfflinePeers()) != 1 { - t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer") + if len(networkMap.GetOfflinePeers()) != 2 { + t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer") } expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=" diff --git a/management/server/peer.go b/management/server/peer.go index a02e34e0d..a95ae17a3 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID var peers []*nbpeer.Peer for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired { + continue + } + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) if expired { peers = append(peers, peer) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8189548b7..0ff57b752 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng var peers []*nbpeer.Peer result := tx. - Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true). Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8ea6c2ae5..5a5616abc 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { { name: "should retrieve peers for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 4, + expectedCount: 5, }, { name: "should return no peers for a non-existing account ID", @@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { name: "should filter peers by partial name", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", nameFilter: "host", - expectedCount: 3, + expectedCount: 4, }, { name: "should filter peers by ip", @@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - accountID string - expectedCount int + name string + accountID string + expectedCount int + expectedPeerIDs []string }{ { - name: "should retrieve peers with expiration for an existing account ID", - accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 1, + name: "should retrieve only non-expired peers with expiration enabled", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + expectedPeerIDs: []string{"notexpired01"}, }, { name: "should return no peers with expiration for a non-existing account ID", @@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) + for i, peer := range peers { + assert.Equal(t, tt.expectedPeerIDs[i], peer.ID) + } }) } } +func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID) + require.NoError(t, err) + + // Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned") + assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set") + } +} + func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) @@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { name: "should retrieve peers for another valid account ID and user ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", userID: "edafee4e-63fb-11ec-90d6-0242ac120003", - expectedCount: 2, + expectedCount: 3, }, { name: "should return no peers for existing account ID with empty user ID", diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index dfcaeee6f..189bd1262 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); From c07c726ea7a7d6bc9b34c1a5dc138785c7bd1214 Mon Sep 17 00:00:00 2001 From: alsruf36 <33592711+alsruf36@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:20:54 +0900 Subject: [PATCH 03/11] [proxy] Set session cookie path to root (#5915) --- proxy/internal/auth/middleware.go | 1 + proxy/internal/auth/middleware_test.go | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 055e4510f..3b383f8b4 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat http.SetCookie(w, &http.Cookie{ Name: auth.SessionCookieName, Value: token, + Path: "/", HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 16d09800c..2c93d7912 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite) } +func TestSetSessionCookieHasRootPath(t *testing.T) { + w := httptest.NewRecorder() + setSessionCookie(w, "test-token", time.Hour) + + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths") +} + func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) From f732b01a055d4de1fd736f2e8b42196c31f291a5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 23 Apr 2026 21:19:21 +0200 Subject: [PATCH 04/11] [management] unify peer-update test timeout via constant (#5952) peerShouldReceiveUpdate waited 500ms for the expected update message, and every outer wrapper across the management/server test suite paired it with a 1s goroutine-drain timeout. Both were too tight for slower CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent "Timed out waiting for update message" failures in tests like TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and TestNameServerAccountPeersUpdate. Introduce peerUpdateTimeout (5s) next to the helper and use it both in the helper and in every outer wrapper so the two timeouts stay in sync. Only runs down on failure; passing tests return as soon as the channel delivers, so there is no slowdown on green runs. --- management/server/account_test.go | 9 ++++++++- management/server/dns_test.go | 6 +++--- management/server/group_test.go | 14 +++++++------- management/server/nameserver_test.go | 4 ++-- management/server/peer_test.go | 16 ++++++++-------- management/server/policy_test.go | 14 +++++++------- management/server/posture_checks_test.go | 10 +++++----- management/server/route_test.go | 12 ++++++------ management/server/user_test.go | 4 ++-- 9 files changed, 48 insertions(+), 41 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index bcc73d52f..bef791d77 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3253,6 +3253,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel. return manager, updateManager, account, peer1, peer2, peer3 } +// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer +// wrappers wait for an expected update message. Sized for slow CI runners +// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take +// seconds. Only runs down on failure; passing tests return immediately +// when the channel delivers. +const peerUpdateTimeout = 5 * time.Second + func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3271,7 +3278,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd if msg == nil { t.Errorf("Received nil update message, expected valid message") } - case <-time.After(500 * time.Millisecond): + case <-time.After(peerUpdateTimeout): t.Error("Timed out waiting for update message") } } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 0e37a3b22..c443223c6 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/group_test.go b/management/server/group_test.go index fa818e532..5821b90a3 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d10d4464f..b2c8300d6 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 6f8d924fd..050baa595 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1907,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1929,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1994,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2012,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2058,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2076,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2113,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2131,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index a3f987732..a553b7d05 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7f0a48dc7..394f0d896 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/route_test.go b/management/server/route_test.go index 91b2cf982..91bd8b050 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } @@ -2107,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2127,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2145,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2185,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -2225,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/user_test.go b/management/server/user_test.go index 8fdfbd633..c77ea53d1 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) @@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { select { case <-done: - case <-time.After(time.Second): + case <-time.After(peerUpdateTimeout): t.Error("timeout waiting for peerShouldReceiveUpdate") } }) From d6f08e48408a7a09995076c4bd595832555eb407 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 24 Apr 2026 13:13:27 +0200 Subject: [PATCH 05/11] [misc] Update sign pipeline version (#5981) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5ada1033d..1d29c8406 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.1.2" + SIGN_PIPE_VER: "v0.1.3" GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" From 34167c8a160d66668eb5592fb2589b13adcd8ee0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 27 Apr 2026 10:55:38 +0200 Subject: [PATCH 06/11] [misc] Update release pipeline version (#5995) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1d29c8406..826c05ff3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.1.3" + SIGN_PIPE_VER: "v0.1.4" GORELEASER_VER: "v2.14.3" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" From 154b81645a5922d0fa1fbfff6785798a14640e02 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:02:54 +0200 Subject: [PATCH 07/11] [management] removed legacy network map code (#5565) --- .../network_map/controller/controller.go | 269 +- .../controllers/network_map/interface.go | 3 - management/server/account_test.go | 98 +- .../http/handlers/peers/peers_handler.go | 2 +- management/server/peer_test.go | 11 - management/server/route_test.go | 245 -- management/server/store/sql_store.go | 2 - management/server/types/account.go | 403 --- management/server/types/account_test.go | 399 --- management/server/types/holder.go | 47 - management/server/types/networkmap.go | 67 - .../types/networkmap_comparison_test.go | 592 ----- .../server/types/networkmap_components.go | 2 - .../types/networkmap_components_test.go | 787 ++++++ .../server/types/networkmap_golden_test.go | 967 ------- management/server/types/networkmapbuilder.go | 2317 ----------------- 16 files changed, 807 insertions(+), 5404 deletions(-) delete mode 100644 management/server/types/holder.go delete mode 100644 management/server/types/networkmap.go delete mode 100644 management/server/types/networkmap_comparison_test.go create mode 100644 management/server/types/networkmap_components_test.go delete mode 100644 management/server/types/networkmap_golden_test.go delete mode 100644 management/server/types/networkmapbuilder.go diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 4b414df6f..4b47ecaa0 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -7,7 +7,6 @@ import ( "os" "slices" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -16,11 +15,9 @@ import ( "golang.org/x/exp/maps" "golang.org/x/mod/semver" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" - "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -58,13 +55,6 @@ type Controller struct { proxyController port_forwarding.Controller integratedPeerValidator integrated_validator.IntegratedValidator - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} - - compactedNetworkMap bool } type bufferUpdate struct { @@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App log.Fatal(fmt.Errorf("error creating metrics: %w", err)) } - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) - newNetworkMapBuilder = false - } - - compactedNetworkMap := true - compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted) - parsedCompactedNmap, err := strconv.ParseBool(compactedEnv) - if err != nil && len(compactedEnv) > 0 { - log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err) - } - if err == nil && !parsedCompactedNmap { - log.WithContext(ctx).Info("disabling compacted mode") - compactedNetworkMap = false - } - - ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - return &Controller{ repo: newRepository(store), metrics: nMetrics, @@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App proxyController: proxyController, EphemeralPeersManager: ephemeralPeersManager, - - holder: types.NewHolder(), - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, - - compactedNetworkMap: compactedNetworkMap, } } @@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("failed to get account: %v", err) - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) } globalStart := time.Now() @@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin routers := account.GetResourceRoutersMap() groupIDToUserIDs := account.GetActiveGroupUsers() - if c.experimentalNetworkMap(accountID) { - c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin c.metrics.CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountID): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID // UpdatePeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { - if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return fmt.Errorf("recalculate network map cache: %v", err) - } - return c.sendUpdateAccountPeers(ctx, accountID) } @@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return err } - var remotePeerNetworkMap *types.NetworkMap - - switch { - case c.experimentalNetworkMap(accountId): - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - case c.compactedNetworkMap: - remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - default: - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, emptyMap, nil, 0, nil } - var ( - account *types.Account - err error - ) - if c.experimentalNetworkMap(accountID) { - account = c.getAccountFromHolderOrInit(ctx, accountID) - } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, 0, err - } + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err } account.InjectProxyPolicies(ctx) @@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return nil, nil, nil, 0, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(accountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) - } else { - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) - } - } + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr return peer, networkMap, postureChecks, dnsFwdPort, nil } -func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - c.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (c *Controller) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := c.getAccountFromHolderOrInit(ctx, accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - - return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) { - c.enrichAccountFromHolder(account) - account.OnPeersAddedUpdNetworkMapCache(peerIds...) -} - -func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - c.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := c.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - c.updateAccountInHolder(account) -} - -func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if c.experimentalNetworkMap(accountId) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - c.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (c *Controller) experimentalNetworkMap(accountId string) bool { - _, ok := c.expNewNetworkMapAIDs[accountId] - return c.expNewNetworkMap || ok -} - -func (c *Controller) enrichAccountFromHolder(account *types.Account) { - a := c.holder.GetAccount(account.Id) - if a == nil { - c.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - c.holder.AddAccount(account) -} - -func (c *Controller) getAccountFromHolder(accountID string) *types.Account { - return c.holder.GetAccount(accountID) -} - -func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account { - a := c.holder.GetAccount(accountID) - if a != nil { - return a - } - account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (c *Controller) updateAccountInHolder(account *types.Account) { - c.holder.AddAccount(account) -} - // GetDNSDomain returns the configured dnsDomain func (c *Controller) GetDNSDomain(settings *types.Settings) string { if settings == nil { @@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t } func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { - peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) - if err != nil { - return fmt.Errorf("failed to get peers by ids: %w", err) - } - - for _, peer := range peers { - c.UpdatePeerInNetworkMapCache(accountID, peer) - } - - err = c.bufferSendUpdateAccountPeers(ctx, accountID) + err := c.bufferSendUpdateAccountPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) } @@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs) - c.onPeersAddedUpdNetworkMapCache(account, peerIDs...) - } return c.bufferSendUpdateAccountPeers(ctx, accountID) } @@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) - - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - continue - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) - continue - } - } } return c.bufferSendUpdateAccountPeers(ctx, accountID) @@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return nil, err } - var networkMap *types.NetworkMap - - if c.experimentalNetworkMap(peer.AccountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) - } else { - account.InjectProxyPolicies(ctx) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - if c.compactedNetworkMap { - networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) - } - } + account.InjectProxyPolicies(ctx) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 64caac861..cfea2d3de 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -12,9 +12,6 @@ import ( ) const ( - EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" - DnsForwarderPort = nbdns.ForwarderServerPort OldForwarderPort = nbdns.ForwarderClientPort DnsForwarderPortMinVersion = "v0.59.0" diff --git a/management/server/account_test.go b/management/server/account_test.go index bef791d77..756c42421 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SaveGroup(t) -} - func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePolicy(t) -} - func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_SavePolicy(t) -} - func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeletePeer(t) -} - func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } -func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testAccountManager_NetworkUpdates_DeleteGroup(t) -} - func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) } -func TestAccount_GetRoutesToSync(t *testing.T) { - _, prefix, err := route.ParseNetwork("192.168.64.0/24") - if err != nil { - t.Fatal(err) - } - _, prefix2, err := route.ParseNetwork("192.168.0.0/24") - if err != nil { - t.Fatal(err) - } - account := &types.Account{ - Peers: map[string]*nbpeer.Peer{ - "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, - }, - Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, - Routes: map[route.ID]*route.Route{ - "route-1": { - ID: "route-1", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-1", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-2": { - ID: "route-2", - Network: prefix2, - NetID: "network-2", - Description: "network-2", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - "route-3": { - ID: "route-3", - Network: prefix, - NetID: "network-1", - Description: "network-1", - Peer: "peer-2", - NetworkType: 0, - Masquerade: false, - Metric: 999, - Enabled: true, - Groups: []string{"group1"}, - }, - }, - } - - routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2")) - - assert.Len(t, routes, 2) - routeIDs := make(map[route.ID]struct{}, 2) - for _, r := range routes { - routeIDs[r.ID] = struct{}{} - } - assert.Contains(t, routeIDs, route.ID("route-2")) - assert.Contains(t, routeIDs, route.ID("route-3")) - - emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3")) - - assert.Len(t, emptyRoutes, 0) -} - func TestAccount_Copy(t *testing.T) { account := &types.Account{ Id: "account1", @@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) { AccountID: "account1", }, }, - NetworkMapCache: &types.NetworkMapBuilder{}, } - account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6b9a69f04..bf6937a49 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) + netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 050baa595..17202597a 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { testGetNetworkMapGeneral(t) } -func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testGetNetworkMapGeneral(t) -} - func testGetNetworkMapGeneral(t *testing.T) { manager, _, err := createManager(t) if err != nil { @@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } -func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") - testUpdateAccountPeers(t) -} - func TestUpdateAccountPeers(t *testing.T) { testUpdateAccountPeers(t) } @@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } diff --git a/management/server/route_test.go b/management/server/route_test.go index 91bd8b050..d0caf4b9b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,10 +2,8 @@ package server import ( "context" - "fmt" "net" "net/netip" - "sort" "testing" "time" @@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) @@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) - - t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 4) - - expectedRoutesFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "route1:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "route1:peerA", - }, - } - additionalFirewallRule := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "tcp", - Port: 80, - RouteID: "route4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.168.10.0/16", - Protocol: "all", - RouteID: "route4:peerA", - }, - } - - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) - - // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) - assert.Len(t, routesFirewallRules, 2) - for _, rule := range expectedRoutesFirewallRules { - rule.RouteID = "route1:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) - assert.Len(t, routesFirewallRules, 3) - - expectedRoutesFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: existingNetwork.String(), - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "route2", - }, - { - SourceRanges: []string{"0.0.0.0/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - { - SourceRanges: []string{"::/0"}, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "route3", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) - - // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, routesFirewallRules, 0) - }) - -} - -// orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { - for _, rule := range ruleList { - sort.Strings(rule.SourceRanges) - } - return ruleList } func TestRouteAccountPeersUpdate(t *testing.T) { @@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, } - validatedPeers := make(map[string]struct{}) - for p := range account.Peers { - validatedPeers[p] = struct{}{} - } - t.Run("validate applied policies for different network resources", func(t *testing.T) { // Test case: Resource1 is directly applied to the policy (policyResource1) policies := account.GetPoliciesForNetworkResource("resource1") @@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { policies = account.GetPoliciesForNetworkResource("resource6") assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") }) - - t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { - resourcePoliciesMap := account.GetResourcePoliciesMap() - resourceRoutersMap := account.GetResourceRoutersMap() - _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) - firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 4) - assert.Len(t, sourcePeers, 5) - - expectedFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 80, - RouteID: "resource2:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerCIp), - fmt.Sprintf(types.AllowedIPsFormat, peerHIp), - fmt.Sprintf(types.AllowedIPsFormat, peerBIp), - }, - Action: "accept", - Destination: "192.168.0.0/16", - Protocol: "all", - Port: 320, - RouteID: "resource2:peerA", - }, - } - - additionalFirewallRules := []*types.RouteFirewallRule{ - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerJIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "tcp", - Port: 80, - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - { - SourceRanges: []string{ - fmt.Sprintf(types.AllowedIPsFormat, peerKIp), - }, - Action: "accept", - Destination: "192.0.2.0/32", - Protocol: "all", - Domains: domain.List{"example.com"}, - IsDynamic: true, - RouteID: "resource4:peerA", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) - - // peerD is also the routing peer for resource2 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 2) - for _, rule := range expectedFirewallRules { - rule.RouteID = "resource2:peerD" - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - assert.Len(t, sourcePeers, 3) - - // peerE is a single routing peer for resource1 and resource3 - // PeerE should only receive rules for resource1 since resource3 has no applied policy - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 2) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, - Action: "accept", - Destination: "10.10.10.0/24", - Protocol: "tcp", - PortRange: types.RulePortRange{Start: 80, End: 350}, - RouteID: "resource1:peerE", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - // peerC is part of distribution groups for resource2 but should not receive the firewall rules - firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) - assert.Len(t, firewallRules, 0) - - // peerL is the single routing peer for resource5 - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) - assert.Len(t, firewallRules, 1) - assert.Len(t, sourcePeers, 1) - - expectedFirewallRules = []*types.RouteFirewallRule{ - { - SourceRanges: []string{"100.65.29.67/32"}, - Action: "accept", - Destination: "10.12.12.1/32", - Protocol: "tcp", - Port: 8080, - RouteID: "resource5:peerL", - }, - } - assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 0) - - _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) - assert.Len(t, routes, 1) - assert.Len(t, sourcePeers, 2) - }) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0ff57b752..0a716d08d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types account.NameServerGroups[ns.ID] = &ns } account.NameServerGroupsG = nil - account.InitOnce() return &account, nil } @@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc if sExtraIntegratedValidatorGroups.Valid { _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) } - account.InitOnce() return &account, nil } diff --git a/management/server/types/account.go b/management/server/types/account.go index c448813db..e7c1e2dce 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,7 +8,6 @@ import ( "slices" "strconv" "strings" - "sync" "time" "github.com/hashicorp/go-multierror" @@ -27,7 +26,6 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" @@ -110,16 +108,9 @@ type Account struct { NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` - NetworkMapCache *NetworkMapBuilder `gorm:"-"` - nmapInitOnce *sync.Once `gorm:"-"` - ReverseProxyFreeDomainNonce string } -func (a *Account) InitOnce() { - a.nmapInitOnce = &sync.Once{} -} - // this class is used by gorm only type PrimaryAccountInfo struct { IsDomainPrimaryAccount bool @@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { o.SignupFormPending == onboarding.SignupFormPending } -// GetRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(LookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) - } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - // GetRoutesByPrefixOrDomains return list of routes by account and route prefix func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { var routes []*route.Route @@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group { return a.Groups[groupID] } -// GetPeerNetworkMap returns the networkmap for the given peer ID. -func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeersMap map[string]struct{}, - resourcePolicies map[string][]*Policy, - routers map[string]map[string]*routerTypes.NetworkRouter, - metrics *telemetry.AccountManagerMetrics, - groupIDToUserIDs map[string][]string, -) *NetworkMap { - start := time.Now() - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - peerGroups := a.GetPeerGroups(peerID) - - aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups) - routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) - var networkResourcesFirewallRules []*RouteFirewallRule - if isRouter { - networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) - } - peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) - - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnectIncludingRouters, - Network: a.Network.Copy(), - Routes: slices.Concat(networkResourcesRoutes, routesUpdate), - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), - AuthorizedUsers: authorizedUsers, - EnableSSH: enableSSH, - } - - if metrics != nil { - objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d", - a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules)) - } - } - - return nm -} - func (a *Account) addNetworksRoutingPeers( networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, @@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers( return peersToConnect } -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.GetPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { for _, peer := range account.Peers { label, err := GetPeerHostLabel(peer.Name, peerLabels) @@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.GetPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - func (a *Account) GetPeerGroups(peerID string) LookupMap { groupList := make(LookupMap) for groupID, group := range a.Groups { @@ -941,8 +684,6 @@ func (a *Account) Copy() *Account { NetworkResources: networkResources, Services: services, Onboarding: a.Onboarding, - NetworkMapCache: a.NetworkMapCache, - nmapInitOnce: a.nmapInitOnce, Domains: domains, } } @@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks { return nil } -// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { var fwRules []*RouteFirewallRule for _, policy := range policies { @@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID return distributionGroupPeers } -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - Domains: route.Domains, - IsDynamic: route.IsDynamic(), - RouteID: route.ID, - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - // GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups // and returns a list of policies that have rules with destinations matching the specified groups. func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { @@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { return resourcePolicies } -// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}, len(a.Peers)) - - for _, resource := range a.NetworkResources { - if !resource.Enabled { - continue - } - - var addSourcePeers bool - - networkRoutingPeers, exists := routers[resource.NetworkID] - if exists { - if router, ok := networkRoutingPeers[peerID]; ok { - isRoutingPeer, addSourcePeers = true, true - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) - } - } - - addedResourceRoute := false - for _, policy := range resourcePolicies[resource.ID] { - var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peers = []string{policy.Rules[0].SourceResource.ID} - } else { - peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) - } - if addSourcePeers { - for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { - allSourcePeers[pID] = struct{}{} - } - } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) - } - addedResourceRoute = true - } - if addedResourceRoute { - break - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { - var dest []string - for _, peerID := range inputPeers { - if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { - dest = append(dest, peerID) - } - } - return dest -} - func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { @@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { return result } -// getNetworkResourcesRoutes convert the network resources list to routes list. -func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { - resourceAppliedPolicies := resourcePolicies[resource.ID] - - var routes []*route.Route - // distribute the resource routes only if there is policy applied to it - if len(resourceAppliedPolicies) > 0 { - peer := a.GetPeer(peerId) - if peer != nil { - routes = append(routes, resource.ToRoute(peer, router)) - } - } - - return routes -} - func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { routers := make(map[string]map[string]*routerTypes.NetworkRouter) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 00ba29b7f..9b1c9e31d 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "net" - "net/netip" - "slices" "testing" "github.com/miekg/dns" @@ -19,7 +17,6 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { require.Len(t, result, 0) } -const ( - accID = "accountID" - network1ID = "network1ID" - group1ID = "group1" - accNetResourcePeer1ID = "peer1" - accNetResourcePeer2ID = "peer2" - accNetResourceRouter1ID = "router1" - accNetResource1ID = "resource1ID" - accNetResourceRestrictPostureCheckID = "restrictPostureCheck" - accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" - accNetResourceLockedPostureCheckID = "lockedPostureCheck" - accNetResourceLinuxPostureCheckID = "linuxPostureCheck" -) - -var ( - accNetResourcePeer1IP = net.IP{192, 168, 1, 1} - accNetResourcePeer2IP = net.IP{192, 168, 1, 2} - accNetResourceRouter1IP = net.IP{192, 168, 1, 3} - accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} -) - -func getBasicAccountsWithResource() *Account { - return &Account{ - Id: accID, - Peers: map[string]*nbpeer.Peer{ - accNetResourcePeer1ID: { - ID: accNetResourcePeer1ID, - AccountID: accID, - Key: "peer1Key", - IP: accNetResourcePeer1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourcePeer2ID: { - ID: accNetResourcePeer2ID, - AccountID: accID, - Key: "peer2Key", - IP: accNetResourcePeer2IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "windows", - WtVersion: "0.34.1", - KernelVersion: "4.4.0", - }, - }, - accNetResourceRouter1ID: { - ID: accNetResourceRouter1ID, - AccountID: accID, - Key: "router1Key", - IP: accNetResourceRouter1IP, - Meta: nbpeer.PeerSystemMeta{ - GoOS: "linux", - WtVersion: "0.35.1", - KernelVersion: "4.4.0", - }, - }, - }, - Groups: map[string]*Group{ - group1ID: { - ID: group1ID, - Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, - }, - }, - Networks: []*networkTypes.Network{ - { - ID: network1ID, - AccountID: accID, - Name: "network1", - }, - }, - NetworkRouters: []*routerTypes.NetworkRouter{ - { - ID: accNetResourceRouter1ID, - NetworkID: network1ID, - AccountID: accID, - Peer: accNetResourceRouter1ID, - PeerGroups: []string{}, - Masquerade: false, - Metric: 100, - Enabled: true, - }, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - { - ID: accNetResource1ID, - AccountID: accID, - NetworkID: network1ID, - Address: "10.10.10.0/24", - Prefix: netip.MustParsePrefix("10.10.10.0/24"), - Type: resourceTypes.NetworkResourceType("subnet"), - Enabled: true, - }, - }, - Policies: []*Policy{ - { - ID: "policy1ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "rule1ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"80"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: nil, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: accNetResourceRestrictPostureCheckID, - Name: accNetResourceRestrictPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.35.0", - }, - }, - }, - { - ID: accNetResourceRelaxedPostureCheckID, - Name: accNetResourceRelaxedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, - }, - }, - { - ID: accNetResourceLockedPostureCheckID, - Name: accNetResourceLockedPostureCheckID, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "7.7.7", - }, - }, - }, - { - ID: accNetResourceLinuxPostureCheckID, - Name: accNetResourceLinuxPostureCheckID, - Checks: posture.ChecksDefinition{ - OSVersionCheck: &posture.OSVersionCheck{ - Linux: &posture.MinKernelVersionCheck{ - MinKernelVersion: "0.0.0"}, - }, - }, - }, - }, - } -} - -func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // all peers should match the policy - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should not match any peer - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 0, "expected rules count don't match") -} - -func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // should allow peer1 to match the policy - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} - - // should allow peer1 and peer2 to match the policy - newPolicy := &Policy{ - ID: "policy2ID", - AccountID: accID, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "policy2ID", - Enabled: true, - Sources: []string{group1ID}, - DestinationResource: Resource{ - ID: accNetResource1ID, - Type: "Host", - }, - Protocol: PolicyRuleProtocolTCP, - Ports: []string{"22"}, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, - } - - account.Policies = append(account.Policies, newPolicy) - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 2, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } - - assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") - assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) - } - if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { - account := getBasicAccountsWithResource() - - // two posture checks should match only the peers that match both checks - policy := account.Policies[0] - policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} - - // validate for peer1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate for peer2 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.False(t, isRouter, "expected router status") - assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") - assert.Len(t, sourcePeers, 0, "expected source peers don't match") - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") - - // validate rules for router1 - rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) - assert.Len(t, rules, 1, "expected rules count don't match") - assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") - assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") - if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { - t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) - } - if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { - t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) - } -} - -func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { - account := getBasicAccountsWithResource() - - account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}} - account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ - ID: "router2Id", - NetworkID: network1ID, - AccountID: accID, - Peer: "router2Id", - }) - - // validate routes for router1 - isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) - assert.True(t, isRouter, "should be router") - assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") - assert.Len(t, sourcePeers, 2, "expected source peers don't match") -} - func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { tests := []struct { name string diff --git a/management/server/types/holder.go b/management/server/types/holder.go deleted file mode 100644 index de8ac8110..000000000 --- a/management/server/types/holder.go +++ /dev/null @@ -1,47 +0,0 @@ -package types - -import ( - "context" - "sync" -) - -type Holder struct { - mu sync.RWMutex - accounts map[string]*Account -} - -func NewHolder() *Holder { - return &Holder{ - accounts: make(map[string]*Account), - } -} - -func (h *Holder) GetAccount(id string) *Account { - h.mu.RLock() - defer h.mu.RUnlock() - return h.accounts[id] -} - -func (h *Holder) AddAccount(account *Account) { - h.mu.Lock() - defer h.mu.Unlock() - a := h.accounts[account.Id] - if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() { - return - } - h.accounts[account.Id] = account -} - -func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { - h.mu.Lock() - defer h.mu.Unlock() - if acc, ok := h.accounts[id]; ok { - return acc, nil - } - account, err := accGetter(ctx, id) - if err != nil { - return nil, err - } - h.accounts[id] = account - return account, nil -} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go deleted file mode 100644 index 68c988a93..000000000 --- a/management/server/types/networkmap.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "context" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" -) - -func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { - if a.NetworkMapCache != nil { - return - } - a.nmapInitOnce.Do(func() { - a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) - }) -} - -func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} - -func (a *Account) GetPeerNetworkMapExp( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - accountZones []*zones.Zone, - validatedPeers map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - a.initNetworkMapBuilder(validatedPeers) - return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics) -} - -func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId) -} - -func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...) -} - -func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { - if a.NetworkMapCache == nil { - return nil - } - return a.NetworkMapCache.OnPeerDeleted(a, peerId) -} - -func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { - if a.NetworkMapCache == nil { - return - } - a.NetworkMapCache.UpdatePeer(peer) -} - -func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { - a.initNetworkMapBuilder(validatedPeers) -} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go deleted file mode 100644 index c5844cca0..000000000 --- a/management/server/types/networkmap_comparison_test.go +++ /dev/null @@ -1,592 +0,0 @@ -package types - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - nbdns "github.com/netbirdio/netbird/dns" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" -) - -func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - if components == nil { - t.Fatal("GetPeerNetworkMapComponents returned nil") - } - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - - if newNetworkMap == nil { - t.Fatal("CalculateNetworkMapFromComponents returned nil") - } - - compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) -} - -func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { - account := createTestAccount() - ctx := context.Background() - - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid == offlinePeerID { - continue - } - validatedPeersMap[pid] = struct{}{} - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - legacyNetworkMap := account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") - - newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) - require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") - - normalizeAndSortNetworkMap(legacyNetworkMap) - normalizeAndSortNetworkMap(newNetworkMap) - - componentsJSON, err := json.MarshalIndent(components, "", " ") - require.NoError(t, err, "error marshaling components to JSON") - - legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - goldenDir := filepath.Join("testdata", "comparison") - err = os.MkdirAll(goldenDir, 0755) - require.NoError(t, err) - - legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") - err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) - require.NoError(t, err, "error writing legacy golden file") - - newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") - err = os.WriteFile(newGoldenPath, newJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - componentsPath := filepath.Join(goldenDir, "components.json") - err = os.WriteFile(componentsPath, componentsJSON, 0644) - require.NoError(t, err, "error writing components golden file") - - require.JSONEq(t, string(legacyJSON), string(newJSON), - "NetworkMaps from legacy and components approaches do not match.\n"+ - "Legacy JSON saved to: %s\n"+ - "Components JSON saved to: %s", - legacyGoldenPath, newGoldenPath) - - t.Logf("✅ NetworkMaps are identical") - t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) - t.Logf(" Components NetworkMap: %s", newGoldenPath) -} - -func normalizeAndSortNetworkMap(nm *NetworkMap) { - if nm == nil { - return - } - - sort.Slice(nm.Peers, func(i, j int) bool { - return nm.Peers[i].ID < nm.Peers[j].ID - }) - - sort.Slice(nm.OfflinePeers, func(i, j int) bool { - return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID - }) - - sort.Slice(nm.Routes, func(i, j int) bool { - return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) - }) - - sort.Slice(nm.FirewallRules, func(i, j int) bool { - if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { - return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP - } - if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { - return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction - } - if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { - return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol - } - if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { - return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port - } - return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID - }) - - for i := range nm.RoutesFirewallRules { - sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) - } - - sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { - if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { - return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination - } - - minLen := len(nm.RoutesFirewallRules[i].SourceRanges) - if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { - minLen = len(nm.RoutesFirewallRules[j].SourceRanges) - } - for k := 0; k < minLen; k++ { - if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { - return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] - } - } - if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { - return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) - } - - if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { - return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) - } - - if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { - return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID - } - - if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { - return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port - } - - return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol - }) - - if nm.DNSConfig.CustomZones != nil { - for i := range nm.DNSConfig.CustomZones { - sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { - return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name - }) - } - } - - if len(nm.DNSConfig.NameServerGroups) != 0 { - sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { - return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name - }) - } -} - -func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) { - t.Helper() - - if legacy.Network.Serial != current.Network.Serial { - t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial) - } - - if len(legacy.Peers) != len(current.Peers) { - t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers)) - } - - legacyPeerIDs := make(map[string]bool) - for _, p := range legacy.Peers { - legacyPeerIDs[p.ID] = true - } - - for _, p := range current.Peers { - if !legacyPeerIDs[p.ID] { - t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID) - } - } - - if len(legacy.OfflinePeers) != len(current.OfflinePeers) { - t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers)) - } - - if len(legacy.FirewallRules) != len(current.FirewallRules) { - t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules)) - } - - if len(legacy.Routes) != len(current.Routes) { - t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes)) - } - - if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) { - t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules)) - } - - if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable { - t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable) - } -} - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" - expiredPeerID = "peer-98" - offlinePeerID = "peer-99" - routingPeerID = "peer-95" - testAccountID = "account-comparison-test" -) - -func createTestAccount() *Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - } - - groups := map[string]*Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - } - - policies := []*Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, - Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, - Protocol: PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: "HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - account := &Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Network: &Network{ - Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*nbdns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - return account -} - -func BenchmarkLegacyNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMap( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - nil, - groupIDToUserIDs, - ) - } -} - -func BenchmarkComponentsNetworkMap(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} - -func BenchmarkComponentsCreation(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - } -} - -func BenchmarkCalculationFromComponents(b *testing.B) { - account := createTestAccount() - ctx := context.Background() - peerID := testingPeerID - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - pid := fmt.Sprintf("peer-%d", i) - if pid != offlinePeerID { - validatedPeersMap[pid] = struct{}{} - } - } - - peersCustomZone := nbdns.CustomZone{} - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - groupIDToUserIDs := account.GetActiveGroupUsers() - - components := account.GetPeerNetworkMapComponents( - ctx, - peerID, - peersCustomZone, - nil, - validatedPeersMap, - resourcePolicies, - routers, - groupIDToUserIDs, - ) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = CalculateNetworkMapFromComponents(ctx, components) - } -} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go index 23d84a994..6f84c8d30 100644 --- a/management/server/types/networkmap_components.go +++ b/management/server/types/networkmap_components.go @@ -19,8 +19,6 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" ) -const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" - type NetworkMapComponents struct { PeerID string diff --git a/management/server/types/networkmap_components_test.go b/management/server/types/networkmap_components_test.go new file mode 100644 index 000000000..dde639ccb --- /dev/null +++ b/management/server/types/networkmap_components_test.go @@ -0,0 +1,787 @@ +package types_test + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap { + t.Helper() + return account.GetPeerNetworkMapFromComponents( + context.Background(), + peerID, + account.GetPeersCustomZone(context.Background(), "netbird.io"), + nil, + validatedPeers, + account.GetResourcePoliciesMap(), + account.GetResourceRoutersMap(), + nil, + account.GetActiveGroupUsers(), + ) +} + +func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} { + excludeSet := make(map[string]struct{}, len(excludePeerIDs)) + for _, id := range excludePeerIDs { + excludeSet[id] = struct{}{} + } + validated := make(map[string]struct{}, len(account.Peers)) + for id := range account.Peers { + if _, excluded := excludeSet[id]; !excluded { + validated[id] = struct{}{} + } + } + return validated +} + +func peerIDs(peers []*nbpeer.Peer) []string { + ids := make([]string, len(peers)) + for i, p := range peers { + ids[i] = p.ID + } + return ids +} + +func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotNil(t, nm) + assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy") + assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself") + assert.Empty(t, nm.OfflinePeers, "no expired peers expected") +} + +func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) { + account := createComponentTestAccount() + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-intra-src", Name: "src <-> src", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-src"}, + }}, + }) + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy") +} + +func TestNetworkMapComponents_FirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated") + + var hasAcceptAll bool + for _, rule := range nm.FirewallRules { + if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) { + hasAcceptAll = true + } + } + assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy") +} + +func TestNetworkMapComponents_LoginExpiration(t *testing.T) { + account := createComponentTestAccount() + account.Settings.PeerLoginExpirationEnabled = true + account.Settings.PeerLoginExpiration = 1 * time.Hour + + expiredTime := time.Now().Add(-2 * time.Hour) + account.Peers["peer-dst-1"].LoginExpirationEnabled = true + account.Peers["peer-dst-1"].LastLogin = &expiredTime + + validated := allPeersValidated(account) + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers") + assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers") +} + +func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-dst-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded") + assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either") +} + +func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account, "peer-src-1") + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router") + assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + var hasResourceRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" { + hasResourceRoute = true + break + } + } + assert.True(t, hasResourceRoute, "router peer should receive network resource route") + assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource") +} + +func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route") + } +} + +func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version"}, + Rules: []*types.PolicyRule{{ + ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-guarded"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id, + }) + + t.Run("peer passes posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasGuardedRoute bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.1.1/32" { + hasGuardedRoute = true + } + } + assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route") + }) + + t.Run("peer fails posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route") + } + }) +} + +func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.PostureChecks = []*posture.Checks{ + {ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}, + }}, + {ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}}, + }}, + } + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id, + SourcePostureChecks: []string{"pc-version", "pc-os"}, + Rules: []*types.PolicyRule{{ + ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-strict"}, + }}, + }) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-strict", Name: "Strict Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id, + }) + + t.Run("passes both posture checks", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.GoOS = "linux" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var found bool + for _, r := range nm.Routes { + if r.Network.String() == "10.200.2.1/32" { + found = true + } + } + assert.True(t, found, "peer passing both checks should get resource route") + }) + + t.Run("fails version posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route") + } + }) + + t.Run("fails OS posture check", func(t *testing.T) { + account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0" + account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0" + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route") + } + }) +} + +func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + var resourceFWRules []*types.RouteFirewallRule + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.200.0.1/32" { + resourceFWRules = append(resourceFWRules, rule) + } + } + assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource") + + var hasSourcePeerIP bool + for _, rule := range resourceFWRules { + for _, sr := range rule.SourceRanges { + if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" { + hasSourcePeerIP = true + } + } + } + assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs") +} + +func TestNetworkMapComponents_DNSManagement(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + t.Run("peer in DNS-enabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled") + }) + + t.Run("peer in DNS-disabled group", func(t *testing.T) { + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled") + }) +} + +func TestNetworkMapComponents_NameServerGroups(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.True(t, nm.DNSConfig.ServiceEnable) + + var hasNSGroup bool + for _, ns := range nm.DNSConfig.NameServerGroups { + if ns.ID == "ns-main" { + hasNSGroup = true + } + } + assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration") +} + +func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-ha-1"] = &route.Route{ + ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + } + account.Routes["route-ha-2"] = &route.Route{ + ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 200, AccountID: account.Id, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + haCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "172.16.0.0/16" { + haCount++ + } + } + assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)") +} + +func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-acl"] = &route.Route{ + ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-dst"}, + PeerGroups: []string{"group-src"}, + AccessControlGroups: []string{"group-dst"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "192.168.100.0/24" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups") +} + +func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Routes["route-open"] = &route.Route{ + ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"), + Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1", + Enabled: true, Metric: 100, AccountID: account.Id, + Groups: []string{"group-src"}, + PeerGroups: []string{"group-src"}, + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasFWRule bool + for _, rule := range nm.RoutesFirewallRules { + if rule.Destination == "10.99.0.0/16" { + hasFWRule = true + } + } + assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules") +} + +func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-dst-1"].SSHEnabled = true + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-ssh", Name: "SSH to dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-dst-1", validated) + assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH") +} + +func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, p := range account.Policies { + p.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers") + assert.Empty(t, nm.FirewallRules) +} + +func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.Routes { + r.Enabled = false + } + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + assert.Empty(t, nm.Routes, "disabled routes should not appear in network map") +} + +func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + for _, r := range account.NetworkResources { + r.Enabled = false + } + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes") + } +} + +func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated) + nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated) + + assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy") + assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy") +} + +func TestNetworkMapComponents_DropPolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-drop", Name: "Drop src->dst", Enabled: true, + Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"5432"}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDropRule bool + for _, rule := range nm.FirewallRules { + if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" { + hasDropRule = true + } + } + assert.True(t, hasDropRule, "drop policy should generate drop firewall rule") +} + +func TestNetworkMapComponents_PortRangePolicy(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0" + + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-range", Name: "Range rule", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasRangeRule bool + for _, rule := range nm.FirewallRules { + if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 { + hasRangeRule = true + } + } + assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule") +} + +func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32", + }) + account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"}, + Resources: []types.Resource{{ID: "resource-2"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-res2", Name: "Access Resource 2", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-2"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + resourceRouteCount := 0 + for _, r := range nm.Routes { + if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" { + resourceRouteCount++ + } + } + assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources") +} + +func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com", + }) + account.Groups["group-res-domain"] = &types.Group{ + ID: "group-res-domain", Name: "Domain Resource Group", + Resources: []types.Resource{{ID: "resource-domain"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-domain", Name: "Access domain resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-domain"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-src-1", validated) + + var hasDomainRoute bool + for _, r := range nm.Routes { + if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" { + hasDomainRoute = true + } + } + assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource") +} + +func TestNetworkMapComponents_NetworkEmpty(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + nm := networkMapFromComponents(t, account, "nonexistent-peer", validated) + + assert.NotNil(t, nm) + assert.Empty(t, nm.Peers) + assert.Empty(t, nm.FirewallRules) + assert.NotNil(t, nm.Network) +} + +func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) { + account := createComponentTestAccount() + validated := allPeersValidated(account) + + account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{ + ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32", + }) + account.Networks = append(account.Networks, &networkTypes.Network{ + ID: "net-other", Name: "Other Net", AccountID: account.Id, + }) + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id, + }) + account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group", + Resources: []types.Resource{{ID: "resource-other"}}, + } + account.Policies = append(account.Policies, &types.Policy{ + ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id, + Rules: []*types.PolicyRule{{ + ID: "rule-other", Name: "Other resource access", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-other"}, + }}, + }) + + nm := networkMapFromComponents(t, account, "peer-router-1", validated) + + for _, r := range nm.Routes { + assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources") + } +} + +func createComponentTestAccount() *types.Account { + peers := map[string]*nbpeer.Peer{ + "peer-src-1": { + ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-src-2": { + ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-dst-1": { + ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + "peer-router-1": { + ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"}, + }, + } + + groups := map[string]*types.Group{ + "group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}}, + "group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}}, + "group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}}, + "group-res": { + ID: "group-res", Name: "Resource Group", + Resources: []types.Resource{{ID: "resource-1"}}, + }, + } + + policies := []*types.Policy{ + { + ID: "policy-base", Name: "Base connectivity", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-base", Name: "Allow src <-> dst", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Bidirectional: true, + Sources: []string{"group-src"}, Destinations: []string{"group-dst"}, + }}, + }, + { + ID: "policy-resource", Name: "Network resource access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: "rule-resource", Name: "Source -> Resource", Enabled: true, + Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL, + Sources: []string{"group-src"}, + DestinationResource: types.Resource{ID: "resource-1"}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + "route-main": { + ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1", + Enabled: true, Metric: 100, + Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"}, + }, + } + + users := map[string]*types.User{ + "user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + "user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}}, + } + + account := &types.Account{ + Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Users: users, + Network: &types.Network{ + Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + "ns-main": { + ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{}, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true, + Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32", + }, + }, + Networks: []*networkTypes.Network{ + {ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"}, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go deleted file mode 100644 index 53261f22d..000000000 --- a/management/server/types/networkmap_golden_test.go +++ /dev/null @@ -1,967 +0,0 @@ -package types_test - -import ( - "context" - "encoding/json" - "fmt" - "net" - "net/netip" - "os" - "path/filepath" - "slices" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/route" -) - -const ( - numPeers = 100 - devGroupID = "group-dev" - opsGroupID = "group-ops" - allGroupID = "group-all" - sshUsersGroupID = "group-ssh-users" - routeID = route.ID("route-main") - routeHA1ID = route.ID("route-ha-1") - routeHA2ID = route.ID("route-ha-2") - policyIDDevOps = "policy-dev-ops" - policyIDAll = "policy-all" - policyIDPosture = "policy-posture" - policyIDDrop = "policy-drop" - policyIDSSH = "policy-ssh" - postureCheckID = "posture-check-ver" - networkResourceID = "res-database" - networkID = "net-database" - networkRouterID = "router-database" - nameserverGroupID = "ns-group-main" - testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. - expiredPeerID = "peer-98" // This peer will be online but with an expired session. - offlinePeerID = "peer-99" // This peer will be completely offline. - routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. - testAccountID = "account-golden-test" - userAdminID = "user-admin" - userDevID = "user-dev" - userOpsID = "user-ops" -) - -func TestGetPeerNetworkMap_Golden(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - b.ResetTimer() - b.Run("old builder", func(b *testing.B) { - for range b.N { - for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - b.ResetTimer() - b.Run("new builder", func(b *testing.B) { - for range b.N { - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - for _, peerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newPeerID := "peer-new-101" - newPeerIP := net.IP{100, 64, 1, 1} - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: newPeerIP, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newPeerID] = newPeer - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = append(devGroup.Peers, newPeerID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newPeerID) - } - - validatedPeersMap[newPeerID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newPeerID) - require.NoError(t, err, "error adding peer to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newPeerID := "peer-new-101" - newPeer := &nbpeer.Peer{ - ID: newPeerID, - IP: net.IP{100, 64, 1, 1}, - Key: fmt.Sprintf("key-%s", newPeerID), - DNSLabel: "peernew101", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - } - - account.Peers[newPeerID] = newPeer - account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) - account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) - validatedPeersMap[newPeerID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerAddedIncremental(account, newRouterID) - require.NoError(t, err, "error adding router to cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - b.ResetTimer() - b.Run("old builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after add", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerAddedIncremental(account, newRouterID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) - - if devGroup, exists := account.Groups[devGroupID]; exists { - devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { - return id == deletedPeerID - }) - } - - delete(validatedPeersMap, deletedPeerID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedPeerID) - require.NoError(t, err, "error deleting peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match") - } -} - -func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - deletedRouterID := "peer-75" - - var affectedRoute *route.Route - for _, r := range account.Routes { - if r.PeerID == deletedRouterID { - affectedRoute = r - break - } - } - require.NotNil(t, affectedRoute, "Router peer should have a route") - - for _, group := range account.Groups { - group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { - return id == deletedRouterID - }) - } - - for routeID, r := range account.Routes { - if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { - delete(account.Routes, routeID) - } - } - delete(account.Peers, deletedRouterID) - delete(validatedPeersMap, deletedRouterID) - - if account.Network != nil { - account.Network.Serial++ - } - - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers()) - normalizeAndSortNetworkMap(legacyNetworkMap) - legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ") - require.NoError(t, err, "error marshaling legacy network map to JSON") - - err = builder.OnPeerDeleted(account, deletedRouterID) - require.NoError(t, err, "error deleting routing peer from cache") - - newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - normalizeAndSortNetworkMap(newNetworkMap) - newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ") - require.NoError(t, err, "error marshaling new network map to JSON") - - if string(legacyJSON) != string(newJSON) { - legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") - newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") - - err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755) - require.NoError(t, err) - - err = os.WriteFile(legacyFilePath, legacyJSON, 0644) - require.NoError(t, err) - t.Logf("Saved legacy network map to %s", legacyFilePath) - - err = os.WriteFile(newFilePath, newJSON, 0644) - require.NoError(t, err) - t.Logf("Saved new network map to %s", newFilePath) - - require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match") - } -} - -func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { - account := createTestAccountWithEntities() - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - var peerIDs []string - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - validatedPeersMap[peerID] = struct{}{} - peerIDs = append(peerIDs, peerID) - } - - deletedPeerID := "peer-25" - - delete(account.Peers, deletedPeerID) - account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { - return id == deletedPeerID - }) - delete(validatedPeersMap, deletedPeerID) - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - b.ResetTimer() - b.Run("old builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers()) - } - } - }) - - b.ResetTimer() - b.Run("new builder after delete", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = builder.OnPeerDeleted(account, deletedPeerID) - for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - } - } - }) -} - -func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { - for _, peer := range networkMap.Peers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - for _, peer := range networkMap.OfflinePeers { - if peer.Status != nil { - peer.Status.LastSeen = time.Time{} - } - peer.LastLogin = &time.Time{} - } - - sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) - sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) - sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) - - sort.Slice(networkMap.FirewallRules, func(i, j int) bool { - r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] - if r1.PeerIP != r2.PeerIP { - return r1.PeerIP < r2.PeerIP - } - if r1.Protocol != r2.Protocol { - return r1.Protocol < r2.Protocol - } - if r1.Direction != r2.Direction { - return r1.Direction < r2.Direction - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - return r1.Port < r2.Port - }) - - sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { - r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] - if r1.RouteID != r2.RouteID { - return r1.RouteID < r2.RouteID - } - if r1.Action != r2.Action { - return r1.Action < r2.Action - } - if r1.Destination != r2.Destination { - return r1.Destination < r2.Destination - } - if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { - if r1.SourceRanges[0] != r2.SourceRanges[0] { - return r1.SourceRanges[0] < r2.SourceRanges[0] - } - } - return r1.Port < r2.Port - }) - - for _, ranges := range networkMap.RoutesFirewallRules { - sort.Slice(ranges.SourceRanges, func(i, j int) bool { - return ranges.SourceRanges[i] < ranges.SourceRanges[j] - }) - } -} - -type networkMapJSON struct { - Peers []*nbpeer.Peer `json:"Peers"` - Network *types.Network `json:"Network"` - Routes []*route.Route `json:"Routes"` - DNSConfig dns.Config `json:"DNSConfig"` - OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"` - FirewallRules []*types.FirewallRule `json:"FirewallRules"` - RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"` - ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"` - AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"` - EnableSSH bool `json:"EnableSSH"` -} - -func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON { - result := &networkMapJSON{ - Peers: nm.Peers, - Network: nm.Network, - Routes: nm.Routes, - DNSConfig: nm.DNSConfig, - OfflinePeers: nm.OfflinePeers, - FirewallRules: nm.FirewallRules, - RoutesFirewallRules: nm.RoutesFirewallRules, - ForwardingRules: nm.ForwardingRules, - EnableSSH: nm.EnableSSH, - } - - if len(nm.AuthorizedUsers) > 0 { - result.AuthorizedUsers = make(map[string][]string) - localUsers := make([]string, 0, len(nm.AuthorizedUsers)) - for localUser := range nm.AuthorizedUsers { - localUsers = append(localUsers, localUser) - } - sort.Strings(localUsers) - - for _, localUser := range localUsers { - userIDs := nm.AuthorizedUsers[localUser] - sortedUserIDs := make([]string, 0, len(userIDs)) - for userID := range userIDs { - sortedUserIDs = append(sortedUserIDs, userID) - } - sort.Strings(sortedUserIDs) - result.AuthorizedUsers[localUser] = sortedUserIDs - } - } - - return result -} - -func createTestAccountWithEntities() *types.Account { - peers := make(map[string]*nbpeer.Peer) - devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} - - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - ip := net.IP{100, 64, 0, byte(i + 1)} - wtVersion := "0.25.0" - if i%2 == 0 { - wtVersion = "0.40.0" - } - - p := &nbpeer.Peer{ - ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, - } - - if peerID == expiredPeerID { - p.LoginExpirationEnabled = true - pastTimestamp := time.Now().Add(-2 * time.Hour) - p.LastLogin = &pastTimestamp - } - - peers[peerID] = p - allGroupPeers = append(allGroupPeers, peerID) - if i < numPeers/2 { - devGroupPeers = append(devGroupPeers, peerID) - } else { - opsGroupPeers = append(opsGroupPeers, peerID) - } - - } - - groups := map[string]*types.Group{ - allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, - devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, - opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, - sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}}, - } - - policies := []*types.Policy{ - { - ID: policyIDAll, Name: "Default-Allow", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{allGroupID}, Destinations: []string{allGroupID}, - }}, - }, - { - ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, - PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, - Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - }}, - }, - { - ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, - SourcePostureChecks: []string{postureCheckID}, - Rules: []*types.PolicyRule{{ - ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, - Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, - }}, - }, - { - ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true, - Rules: []*types.PolicyRule{{ - ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept, - Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false, - Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, - AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}}, - }}, - }, - } - - routes := map[route.ID]*route.Route{ - routeID: { - ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), - Peer: peers["peer-75"].Key, - PeerID: "peer-75", - Description: "Route to internal resource", Enabled: true, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - }, - routeHA1ID: { - ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-80"].Key, - PeerID: "peer-80", - Description: "HA Route 1", Enabled: true, Metric: 1000, - PeerGroups: []string{allGroupID}, - Groups: []string{allGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - routeHA2ID: { - ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), - Peer: peers["peer-90"].Key, - PeerID: "peer-90", - Description: "HA Route 2", Enabled: true, Metric: 900, - PeerGroups: []string{devGroupID, opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{allGroupID}, - }, - } - - users := map[string]*types.User{ - userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}}, - userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}}, - userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}}, - } - - account := &types.Account{ - Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, - Users: users, - Network: &types.Network{ - Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, - }, - DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, - NameServerGroups: map[string]*dns.NameServerGroup{ - nameserverGroupID: { - ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, - NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, - }, - }, - PostureChecks: []*posture.Checks{ - {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, - }}, - }, - NetworkResources: []*resourceTypes.NetworkResource{ - {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, - }, - Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, - NetworkRouters: []*routerTypes.NetworkRouter{ - {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, - }, - Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, - } - - for _, p := range account.Policies { - p.AccountID = account.Id - } - for _, r := range account.Routes { - r.AccountID = account.Id - } - - return account -} - -func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) { - account := createTestAccountWithEntities() - - ctx := context.Background() - validatedPeersMap := make(map[string]struct{}) - for i := range numPeers { - peerID := fmt.Sprintf("peer-%d", i) - if peerID == offlinePeerID { - continue - } - validatedPeersMap[peerID] = struct{}{} - } - - builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - - newRouterID := "peer-new-router-102" - newRouterIP := net.IP{100, 64, 1, 2} - newRouter := &nbpeer.Peer{ - ID: newRouterID, - IP: newRouterIP, - Key: fmt.Sprintf("key-%s", newRouterID), - DNSLabel: "newrouter102", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, - UserID: "user-admin", - Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, - LastLogin: func() *time.Time { t := time.Now(); return &t }(), - } - - account.Peers[newRouterID] = newRouter - - if opsGroup, exists := account.Groups[opsGroupID]; exists { - opsGroup.Peers = append(opsGroup.Peers, newRouterID) - } - if allGroup, exists := account.Groups[allGroupID]; exists { - allGroup.Peers = append(allGroup.Peers, newRouterID) - } - - newRoute := &route.Route{ - ID: route.ID("route-new-router"), - Network: netip.MustParsePrefix("172.16.0.0/24"), - Peer: newRouter.Key, - PeerID: newRouterID, - Description: "Route from new router", - Enabled: true, - PeerGroups: []string{opsGroupID}, - Groups: []string{devGroupID, opsGroupID}, - AccessControlGroups: []string{devGroupID}, - AccountID: account.Id, - } - account.Routes[newRoute.ID] = newRoute - - validatedPeersMap[newRouterID] = struct{}{} - - if account.Network != nil { - account.Network.Serial++ - } - - builder.EnqueuePeersForIncrementalAdd(account, newRouterID) - - time.Sleep(100 * time.Millisecond) - - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) - - normalizeAndSortNetworkMap(networkMap) - - jsonData, err := json.MarshalIndent(networkMap, "", " ") - require.NoError(t, err, "error marshaling network map to JSON") - - goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") - - t.Log("Update golden file with OnPeerAdded router...") - err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) - require.NoError(t, err) - err = os.WriteFile(goldenFilePath, jsonData, 0644) - require.NoError(t, err) - - expectedJSON, err := os.ReadFile(goldenFilePath) - require.NoError(t, err, "error reading golden file") - - require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") -} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go deleted file mode 100644 index 6448b8403..000000000 --- a/management/server/types/networkmapbuilder.go +++ /dev/null @@ -1,2317 +0,0 @@ -package types - -import ( - "context" - "fmt" - "slices" - "strconv" - "strings" - "sync" - "time" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - "github.com/netbirdio/netbird/client/ssh/auth" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/zones" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" -) - -const ( - allPeers = "0.0.0.0" - allWildcard = "0.0.0.0/0" - v6AllWildcard = "::/0" - fw = "fw:" - rfw = "route-fw:" - - szAddPeerBatch = 10 - maxPeerAddRetries = 20 -) - -type NetworkMapCache struct { - globalRoutes map[route.ID]*route.Route - globalRules map[string]*FirewallRule //ruleId - globalRouteRules map[string]*RouteFirewallRule //ruleId - globalPeers map[string]*nbpeer.Peer - - groupToPeers map[string][]string - peerToGroups map[string][]string - policyToRules map[string][]*PolicyRule //policyId - groupToPolicies map[string][]*Policy - groupToRoutes map[string][]*route.Route - peerToRoutes map[string][]*route.Route - - peerACLs map[string]*PeerACLView - peerRoutes map[string]*PeerRoutesView - peerDNS map[string]*nbdns.Config - peerSSH map[string]*PeerSSHView - - groupIDToUserIDs map[string][]string - allowedUserIDs map[string]struct{} - - resourceRouters map[string]map[string]*routerTypes.NetworkRouter - resourcePolicies map[string][]*Policy - - globalResources map[string]*resourceTypes.NetworkResource // resourceId - - acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info - noACGRoutes map[route.ID]*RouteOwnerInfo - - mu sync.RWMutex -} - -type RouteOwnerInfo struct { - PeerID string - RouteID route.ID -} - -type PeerACLView struct { - ConnectedPeerIDs []string - FirewallRuleIDs []string -} - -type PeerRoutesView struct { - OwnRouteIDs []route.ID - NetworkResourceIDs []route.ID - InheritedRouteIDs []route.ID - RouteFirewallRuleIDs []string -} - -type PeerSSHView struct { - EnableSSH bool - AuthorizedUsers map[string]map[string]struct{} -} - -type NetworkMapBuilder struct { - account *Account - cache *NetworkMapCache - validatedPeers map[string]struct{} - - apb addPeerBatch -} - -type addPeerBatch struct { - mu sync.Mutex - sg *sync.Cond - ids []string - la *Account - retryCount map[string]int -} - -func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { - builder := &NetworkMapBuilder{ - cache: &NetworkMapCache{ - globalRoutes: make(map[route.ID]*route.Route), - globalRules: make(map[string]*FirewallRule), - globalRouteRules: make(map[string]*RouteFirewallRule), - globalPeers: make(map[string]*nbpeer.Peer), - groupToPeers: make(map[string][]string), - peerToGroups: make(map[string][]string), - policyToRules: make(map[string][]*PolicyRule), - groupToPolicies: make(map[string][]*Policy), - groupToRoutes: make(map[string][]*route.Route), - peerToRoutes: make(map[string][]*route.Route), - peerACLs: make(map[string]*PeerACLView), - peerRoutes: make(map[string]*PeerRoutesView), - peerDNS: make(map[string]*nbdns.Config), - peerSSH: make(map[string]*PeerSSHView), - groupIDToUserIDs: make(map[string][]string), - allowedUserIDs: make(map[string]struct{}), - globalResources: make(map[string]*resourceTypes.NetworkResource), - acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), - noACGRoutes: make(map[route.ID]*RouteOwnerInfo), - }, - validatedPeers: make(map[string]struct{}), - } - builder.apb.sg = sync.NewCond(&builder.apb.mu) - builder.apb.ids = make([]string, 0, szAddPeerBatch) - builder.apb.la = account - builder.apb.retryCount = make(map[string]int) - - maps.Copy(builder.validatedPeers, validatedPeers) - - builder.initialBuild(account) - - go builder.incAddPeerLoop() - return builder -} - -func (b *NetworkMapBuilder) initialBuild(account *Account) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - b.account = account - - start := time.Now() - - b.buildGlobalIndexes(account) - - resourceRouters := account.GetResourceRoutersMap() - resourcePolicies := account.GetResourcePoliciesMap() - b.cache.resourceRouters = resourceRouters - b.cache.resourcePolicies = resourcePolicies - - for peerID := range account.Peers { - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - } - - log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) -} - -func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { - clear(b.cache.globalPeers) - clear(b.cache.groupToPeers) - clear(b.cache.peerToGroups) - clear(b.cache.policyToRules) - clear(b.cache.groupToPolicies) - clear(b.cache.globalRoutes) - clear(b.cache.globalRules) - clear(b.cache.globalRouteRules) - clear(b.cache.globalResources) - clear(b.cache.groupToRoutes) - clear(b.cache.peerToRoutes) - clear(b.cache.acgToRoutes) - clear(b.cache.noACGRoutes) - clear(b.cache.groupIDToUserIDs) - clear(b.cache.allowedUserIDs) - clear(b.cache.peerSSH) - - maps.Copy(b.cache.globalPeers, account.Peers) - - b.cache.groupIDToUserIDs = account.GetActiveGroupUsers() - b.cache.allowedUserIDs = b.buildAllowedUserIDs(account) - - for groupID, group := range account.Groups { - peersCopy := make([]string, len(group.Peers)) - copy(peersCopy, group.Peers) - b.cache.groupToPeers[groupID] = peersCopy - - for _, peerID := range group.Peers { - b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) - } - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - b.cache.policyToRules[policy.ID] = policy.Rules - - affectedGroups := make(map[string]struct{}) - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, groupID := range rule.Sources { - affectedGroups[groupID] = struct{}{} - } - for _, groupID := range rule.Destinations { - affectedGroups[groupID] = struct{}{} - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - groupId := rule.SourceResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - groupId := rule.DestinationResource.ID - affectedGroups[groupId] = struct{}{} - b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) - } - } - - for groupID := range affectedGroups { - b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) - } - } - - for _, resource := range account.NetworkResources { - if !resource.Enabled { - continue - } - b.cache.globalResources[resource.ID] = resource - } - - for _, r := range account.Routes { - if !r.Enabled { - continue - } - for _, groupID := range r.PeerGroups { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } -} - -func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { - peer := account.GetPeer(peerID) - if peer == nil { - return - } - - allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers) - - isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) - - var emptyExpiredPeers []*nbpeer.Peer - finalAllPeers := b.addNetworksRoutingPeers( - networkResourcesRoutes, - peer, - allPotentialPeers, - emptyExpiredPeers, - isRouter, - sourcePeers, - ) - - view := &PeerACLView{ - ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), - FirewallRuleIDs: make([]string, 0, len(firewallRules)), - } - - for _, p := range finalAllPeers { - view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) - } - - for _, rule := range firewallRules { - ruleID := b.generateFirewallRuleID(rule) - view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) - b.cache.globalRules[ruleID] = rule - } - - b.cache.peerACLs[peerID] = view - b.cache.peerSSH[peerID] = &PeerSSHView{ - EnableSSH: sshEnabled, - AuthorizedUsers: authorizedUsers, - } -} - -func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, - validatedPeersMap map[string]struct{}, -) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { - peerID := peer.ID - ctx := context.Background() - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - fwRules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - authorizedUsers := make(map[string]map[string]struct{}) - sshEnabled := false - - for _, group := range peerGroups { - policies := b.cache.groupToPolicies[group] - for _, policy := range policies { - if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { - continue - } - rules := b.cache.policyToRules[policy.ID] - for _, rule := range rules { - var sourcePeers, destinationPeers []*nbpeer.Peer - var peerInSources, peerInDestinations bool - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peerInSources = rule.SourceResource.ID == peerID - } else { - peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peerInDestinations = rule.DestinationResource.ID == peerID - } else { - peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) - } - - if !peerInSources && !peerInDestinations { - continue - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - peer := account.GetPeer(rule.SourceResource.ID) - if peer != nil { - sourcePeers = []*nbpeer.Peer{peer} - } - } else { - sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - peer := account.GetPeer(rule.DestinationResource.ID) - if peer != nil { - destinationPeers = []*nbpeer.Peer{peer} - } - } else { - destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) - } - - if rule.Bidirectional { - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - } - - if peerInSources { - b.generateResourcescached( - rule, destinationPeers, FirewallRuleDirectionOUT, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - } - - if peerInDestinations { - b.generateResourcescached( - rule, sourcePeers, FirewallRuleDirectionIN, - peer, &peers, &fwRules, peersExists, rulesExists, - ) - - if rule.Protocol == PolicyRuleProtocolNetbirdSSH { - sshEnabled = true - switch { - case len(rule.AuthorizedGroups) > 0: - for groupID, localUsers := range rule.AuthorizedGroups { - userIDs, ok := b.cache.groupIDToUserIDs[groupID] - if !ok { - continue - } - - if len(localUsers) == 0 { - localUsers = []string{auth.Wildcard} - } - - for _, localUser := range localUsers { - if authorizedUsers[localUser] == nil { - authorizedUsers[localUser] = make(map[string]struct{}) - } - for _, userID := range userIDs { - authorizedUsers[localUser][userID] = struct{}{} - } - } - } - case rule.AuthorizedUser != "": - if authorizedUsers[auth.Wildcard] == nil { - authorizedUsers[auth.Wildcard] = make(map[string]struct{}) - } - authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} - default: - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled { - sshEnabled = true - authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs) - } - } - } - } - } - - return peers, fwRules, authorizedUsers, sshEnabled -} - -func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { - for _, groupID := range groupIDs { - if _, exists := peerGroupsMap[groupID]; exists { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, - excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, -) []*nbpeer.Peer { - ctx := context.Background() - uniquePeers := make(map[string]*nbpeer.Peer) - - for _, groupID := range groupIDs { - peerIDs := b.cache.groupToPeers[groupID] - for _, peerID := range peerIDs { - if peerID == excludePeerID { - continue - } - - if _, ok := validatedPeersMap[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - if len(postureChecksIDs) > 0 { - if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { - continue - } - } - - uniquePeers[peerID] = peer - } - } - - result := make([]*nbpeer.Peer, 0, len(uniquePeers)) - for _, peer := range uniquePeers { - result = append(result, peer) - } - - return result -} - -func (b *NetworkMapBuilder) generateResourcescached( - rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, - peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, -) { - for _, peer := range groupPeers { - if peer == nil { - continue - } - if _, ok := peersExists[peer.ID]; !ok { - *peers = append(*peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PolicyID: rule.ID, - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - var s strings.Builder - s.WriteString(rule.ID) - s.WriteString(fr.PeerIP) - s.WriteString(strconv.Itoa(direction)) - s.WriteString(fr.Protocol) - s.WriteString(fr.Action) - s.WriteString(strings.Join(rule.Ports, ",")) - - ruleID := s.String() - - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { - *rules = append(*rules, &fr) - continue - } - - *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) - } -} - -func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { - ctx := context.Background() - peerID := peer.ID - - var isRoutingPeer bool - var routes []*route.Route - allSourcePeers := make(map[string]struct{}) - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, resource := range b.cache.globalResources { - - networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] - resourcePolicies := b.cache.resourcePolicies[resource.ID] - if len(resourcePolicies) == 0 { - continue - } - - isRouterForThisResource := false - - if networkRoutingPeers != nil { - if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { - isRoutingPeer = true - isRouterForThisResource = true - if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - - hasAccessAsClient := false - if !isRouterForThisResource { - for _, policy := range resourcePolicies { - if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - hasAccessAsClient = true - break - } - } - } - } - - if hasAccessAsClient && networkRoutingPeers != nil { - for routerPeerID, router := range networkRoutingPeers { - if router.Enabled { - if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { - routes = append(routes, rt) - } - } - } - } - - if isRouterForThisResource { - for _, policy := range resourcePolicies { - var peersWithAccess []*nbpeer.Peer - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { - peersWithAccess = []*nbpeer.Peer{peer} - } else { - peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) - } - for _, p := range peersWithAccess { - allSourcePeers[p.ID] = struct{}{} - } - } - } - } - - return isRoutingPeer, routes, allSourcePeers -} - -func (b *NetworkMapBuilder) createNetworkResourceRoutes( - resource *resourceTypes.NetworkResource, routerPeerID string, - router *routerTypes.NetworkRouter, resourcePolicies []*Policy, -) *route.Route { - if len(resourcePolicies) > 0 { - peer := b.cache.globalPeers[routerPeerID] - if peer != nil { - return resource.ToRoute(peer, router) - } - } - return nil -} - -func (b *NetworkMapBuilder) addNetworksRoutingPeers( - networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, - expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, -) []*nbpeer.Peer { - - networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) - for _, r := range networkResourcesRoutes { - networkRoutesPeers[r.PeerID] = struct{}{} - } - - delete(sourcePeers, peer.ID) - delete(networkRoutesPeers, peer.ID) - - for _, existingPeer := range peersToConnect { - delete(sourcePeers, existingPeer.ID) - delete(networkRoutesPeers, existingPeer.ID) - } - for _, expPeer := range expiredPeers { - delete(sourcePeers, expPeer.ID) - delete(networkRoutesPeers, expPeer.ID) - } - - missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) - if isRouter { - for p := range sourcePeers { - missingPeers[p] = struct{}{} - } - } - for p := range networkRoutesPeers { - missingPeers[p] = struct{}{} - } - - for p := range missingPeers { - if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { - peersToConnect = append(peersToConnect, missingPeer) - } - } - - return peersToConnect -} - -func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { - ctx := context.Background() - peer := account.GetPeer(peerID) - if peer == nil { - return - } - resourcePolicies := b.cache.resourcePolicies - - view := &PeerRoutesView{ - OwnRouteIDs: make([]route.ID, 0), - NetworkResourceIDs: make([]route.ID, 0), - RouteFirewallRuleIDs: make([]string, 0), - } - - enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) - for _, rt := range enabledRoutes { - if rt.PeerID != "" && rt.PeerID != peerID { - if b.cache.globalPeers[rt.PeerID] == nil { - continue - } - } - - view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - aclView := b.cache.peerACLs[peerID] - if aclView != nil { - peerRoutesMembership := make(LookupMap) - for _, r := range append(enabledRoutes, disabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - peerGroups := b.cache.peerToGroups[peerID] - peerGroupsMap := make(LookupMap) - for _, groupID := range peerGroups { - peerGroupsMap[groupID] = struct{}{} - } - - for _, aclPeerID := range aclView.ConnectedPeerIDs { - if aclPeerID == peerID { - continue - } - activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) - groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) - haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - - for _, inheritedRoute := range haFilteredRoutes { - view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) - b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute - } - } - } - - _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) - - for _, rt := range networkResourcesRoutes { - view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) - b.cache.globalRoutes[rt.ID] = rt - } - - allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) - b.updateACGIndexForPeer(peerID, allRoutes) - - routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) - for _, rule := range routeFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - - if len(networkResourcesRoutes) > 0 { - networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) - for _, rule := range networkResourceFirewallRules { - ruleID := b.generateRouteFirewallRuleID(rule) - view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) - b.cache.globalRouteRules[ruleID] = rule - } - } - - b.cache.peerRoutes[peerID] = view -} - -func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == peerID { - delete(b.cache.noACGRoutes, routeID) - } - } - - for _, rt := range routes { - if !rt.Enabled { - continue - } - - if len(rt.AccessControlGroups) == 0 { - b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } else { - for _, acg := range rt.AccessControlGroups { - if b.cache.acgToRoutes[acg] == nil { - b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) - } - - b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ - PeerID: peerID, - RouteID: rt.ID, - } - } - } - } -} - -func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - peer := b.cache.globalPeers[peerID] - if peer == nil { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - // maybe here is some mess - here we store peer key (see comment below) - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - peerGroups := b.cache.peerToGroups[peerID] - for _, groupID := range peerGroups { - groupRoutes := b.cache.groupToRoutes[groupID] - for _, r := range groupRoutes { - newPeerRoute := r.Copy() - // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes - newPeerRoute.Peer = peerID - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) - takeRoute(newPeerRoute, peerID) - } - } - for _, r := range b.cache.peerToRoutes[peerID] { - takeRoute(r.Copy(), peerID) - } - return enabledRoutes, disabledRoutes -} - -func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0) - - enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) - for _, route := range enabledRoutes { - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) - continue - } - - distributionPeers := b.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) - - rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - -func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { - routePolicies := make(map[string]*Policy) - - for _, groupID := range accessControlGroups { - candidatePolicies := b.cache.groupToPolicies[groupID] - - for _, policy := range candidatePolicies { - if _, found := routePolicies[policy.ID]; found { - continue - } - policyRules := b.cache.policyToRules[policy.ID] - for _, rule := range policyRules { - if slices.Contains(rule.Destinations, groupID) { - routePolicies[policy.ID] = policy - break - } - } - } - } - - return maps.Values(routePolicies) -} - -func (b *NetworkMapBuilder) getRouteFirewallRules( - peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, - distributionPeers map[string]struct{}, account *Account, -) []*RouteFirewallRule { - ctx := context.Background() - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) - - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) - fwRules = append(fwRules, rules...) - } - } - return fwRules -} - -func (b *NetworkMapBuilder) getRulePeers( - rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, - validatedPeersMap map[string]struct{}, account *Account, -) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - - for _, id := range rule.Sources { - groupPeers := b.cache.groupToPeers[id] - if groupPeers == nil { - continue - } - - for _, pID := range groupPeers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - _, distPeer := distributionPeers[rule.SourceResource.ID] - _, valid := validatedPeersMap[rule.SourceResource.ID] - if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, rule.SourceResource.ID) { - distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := b.cache.globalPeers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { - peerGroups := b.cache.peerToGroups[peerID] - checkGroups := make(map[string]struct{}, len(peerGroups)) - for _, groupID := range peerGroups { - checkGroups[groupID] = struct{}{} - } - - dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) - dnsConfig := &nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) - } - - b.cache.peerDNS[peerID] = dnsConfig -} - -func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { - - enabled := true - for _, groupID := range account.DNSSettings.DisabledManagementGroups { - _, found := checkGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := checkGroups[gID] - if found { - peer := b.cache.globalPeers[peerID] - if !peerIsNameserver(peer, nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} { - users := make(map[string]struct{}) - for _, nbUser := range account.Users { - if !nbUser.IsBlocked() && !nbUser.IsServiceUser { - users[nbUser.Id] = struct{}{} - } - } - return users -} - -func firewallRuleProtocol(protocol PolicyRuleProtocolType) string { - if protocol == PolicyRuleProtocolNetbirdSSH { - return string(PolicyRuleProtocolTCP) - } - return string(protocol) -} - -// lock should be held -func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account { - if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() { - b.account = account - } - return b.account -} - -func (b *NetworkMapBuilder) GetPeerNetworkMap( - ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, - validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - b.cache.mu.RLock() - defer b.cache.mu.RUnlock() - - account := b.account - - peer := account.GetPeer(peerID) - if peer == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - aclView := b.cache.peerACLs[peerID] - routesView := b.cache.peerRoutes[peerID] - dnsConfig := b.cache.peerDNS[peerID] - sshView := b.cache.peerSSH[peerID] - - if aclView == nil || routesView == nil || dnsConfig == nil { - return &NetworkMap{Network: account.Network.Copy()} - } - - nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers) - - if metrics != nil { - objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", - account.Id, objectCount) - } - } - - return nm -} - -func (b *NetworkMapBuilder) assembleNetworkMap( - ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{}, -) *NetworkMap { - - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - - for _, peerID := range aclView.ConnectedPeerIDs { - if _, ok := validatedPeers[peerID]; !ok { - continue - } - - peer := b.cache.globalPeers[peerID] - if peer == nil { - continue - } - - expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) - if account.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, peer) - } else { - peersToConnect = append(peersToConnect, peer) - } - } - - var routes []*route.Route - allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) - - for _, routeID := range allRouteIDs { - if route := b.cache.globalRoutes[routeID]; route != nil { - routes = append(routes, route) - } - } - - var firewallRules []*FirewallRule - for _, ruleID := range aclView.FirewallRuleIDs { - if rule := b.cache.globalRules[ruleID]; rule != nil { - firewallRules = append(firewallRules, rule) - } else { - log.Debugf("NetworkMapBuilder: peer %s assembling network map has no fwrule %s in globalRules", peer.ID, ruleID) - } - } - - var routesFirewallRules []*RouteFirewallRule - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - routesFirewallRules = append(routesFirewallRules, rule) - } - } - - finalDNSConfig := *dnsConfig - if finalDNSConfig.ServiceEnable { - var zones []nbdns.CustomZone - - peerGroupsSlice := b.cache.peerToGroups[peer.ID] - peerGroups := make(LookupMap, len(peerGroupsSlice)) - for _, groupID := range peerGroupsSlice { - peerGroups[groupID] = struct{}{} - } - - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) - } - - filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) - zones = append(zones, filteredAccountZones...) - - finalDNSConfig.CustomZones = zones - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: account.Network.Copy(), - Routes: routes, - DNSConfig: finalDNSConfig, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if sshView != nil { - nm.EnableSSH = sshView.EnableSSH - nm.AuthorizedUsers = sshView.AuthorizedUsers - } - - return nm -} - -func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { - var s strings.Builder - s.WriteString(fw) - s.WriteString(rule.PolicyID) - s.WriteRune(':') - s.WriteString(rule.PeerIP) - s.WriteRune(':') - s.WriteString(strconv.Itoa(rule.Direction)) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(rule.Port) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) - s.WriteRune('-') - s.WriteString(strconv.Itoa(int(rule.PortRange.End))) - return s.String() -} - -func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { - var s strings.Builder - s.WriteString(rfw) - s.WriteString(string(rule.RouteID)) - s.WriteRune(':') - s.WriteString(rule.Destination) - s.WriteRune(':') - s.WriteString(rule.Action) - s.WriteRune(':') - s.WriteString(strings.Join(rule.SourceRanges, ",")) - s.WriteRune(':') - s.WriteString(rule.Protocol) - s.WriteRune(':') - s.WriteString(strconv.Itoa(int(rule.Port))) - return s.String() -} - -func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(peerGroups, groupID) { - return true - } - } - return false -} - -func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { - for _, r := range account.Routes { - if !r.Enabled { - continue - } - - if r.PeerID == peerID { - return true - } - - if peer := b.cache.globalPeers[peerID]; peer != nil { - if r.Peer == peer.Key && r.PeerID == "" { - return true - } - } - } - - routers := account.GetResourceRoutersMap() - for _, networkRouters := range routers { - if router, exists := networkRouters[peerID]; exists && router.Enabled { - return true - } - } - - return false -} - -func (b *NetworkMapBuilder) incAddPeerLoop() { - for { - b.apb.mu.Lock() - if len(b.apb.ids) == 0 { - b.apb.sg.Wait() - } - b.addPeersIncrementally() - b.apb.mu.Unlock() - } -} - -// lock on b.apb level should be held -func (b *NetworkMapBuilder) addPeersIncrementally() { - peers := slices.Clone(b.apb.ids) - clear(b.apb.ids) - b.apb.ids = b.apb.ids[:0] - latestAcc := b.apb.la - b.apb.mu.Unlock() - - tt := time.Now() - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(latestAcc) - - log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers)) - - allUpdates := make(map[string]*PeerUpdateDelta) - - for _, peerID := range peers { - peer := account.GetPeer(peerID) - if peer == nil { - b.apb.mu.Lock() - retries := b.apb.retryCount[peerID] - b.apb.mu.Unlock() - - if retries >= maxPeerAddRetries { - log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries) - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - continue - } - - log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries) - b.apb.mu.Lock() - b.apb.retryCount[peerID] = retries + 1 - b.apb.mu.Unlock() - b.enqueuePeersForIncrementalAdd(latestAcc, peerID) - continue - } - - b.apb.mu.Lock() - delete(b.apb.retryCount, peerID) - b.apb.mu.Unlock() - - b.validatedPeers[peerID] = struct{}{} - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups) - for affectedPeerID, delta := range peerDeltas { - if existing, ok := allUpdates[affectedPeerID]; ok { - existing.mergeFrom(delta) - continue - } - allUpdates[affectedPeerID] = delta - } - } - - for affectedPeerID, delta := range allUpdates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } - - log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt)) - - b.apb.mu.Lock() - if len(b.apb.ids) > 0 { - b.apb.sg.Signal() - } -} - -func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.apb.mu.Lock() - b.apb.ids = append(b.apb.ids, peerIDs...) - if b.apb.la != nil && acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() { - b.apb.la = acc - } - b.apb.sg.Signal() - b.apb.mu.Unlock() -} - -func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) { - b.enqueuePeersForIncrementalAdd(acc, peerIDs...) -} - -type ViewDelta struct { - AddedPeerIDs []string - RemovedPeerIDs []string - AddedRuleIDs []string - RemovedRuleIDs []string -} - -func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error { - tt := time.Now() - peer := acc.GetPeer(peerID) - if peer == nil { - return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID) - } - - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) - - b.validatedPeers[peerID] = struct{}{} - - b.cache.globalPeers[peerID] = peer - - peerGroups := b.updateIndexesForNewPeer(account, peerID) - - b.buildPeerACLView(account, peerID) - b.buildPeerRoutesView(account, peerID) - b.buildPeerDNSView(account, peerID) - - log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) - - b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) - - log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) - - return nil -} - -func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { - peerGroups := make([]string, 0) - - for groupID, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { - b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) - } - peerGroups = append(peerGroups, groupID) - } - } - - b.cache.peerToGroups[peerID] = peerGroups - - for _, r := range account.Routes { - if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { - continue - } - for _, groupID := range r.PeerGroups { - if !slices.Contains(b.cache.groupToRoutes[groupID], r) { - b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) - } - } - if r.Peer != "" { - if peer, ok := b.cache.globalPeers[r.Peer]; ok { - if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { - b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) - } - } - } - b.cache.globalRoutes[r.ID] = r - } - - return peerGroups -} - -func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { - updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups) - for affectedPeerID, delta := range updates { - b.applyDeltaToPeer(account, affectedPeerID, delta) - } -} - -func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) - - if b.isPeerRouter(account, newPeerID) { - affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) - for affectedPeerID := range affectedByRoutes { - if affectedPeerID == newPeerID { - continue - } - if _, exists := updates[affectedPeerID]; !exists { - updates[affectedPeerID] = &PeerUpdateDelta{ - PeerID: affectedPeerID, - RebuildRoutesView: true, - } - } else { - updates[affectedPeerID].RebuildRoutesView = true - } - } - } - - return updates -} - -func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { - affected := make(map[string]struct{}) - enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) - - for _, route := range enabledRoutes { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - - for _, peerGroupID := range route.PeerGroups { - if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { - for _, peerID := range peers { - if peerID != newRouterID { - affected[peerID] = struct{}{} - } - } - } - } - } - - for _, route := range account.Routes { - if !route.Enabled { - continue - } - - routerInPeerGroups := false - for _, peerGroupID := range route.PeerGroups { - if slices.Contains(routerGroups, peerGroupID) { - routerInPeerGroups = true - break - } - } - - if routerInPeerGroups { - for _, distGroupID := range route.Groups { - if peers := b.cache.groupToPeers[distGroupID]; peers != nil { - for _, peerID := range peers { - affected[peerID] = struct{}{} - } - } - } - } - } - - return affected -} - -func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { - updates := make(map[string]*PeerUpdateDelta) - ctx := context.Background() - - groupAllLn := 0 - if allGroup, err := account.GetGroupAll(); err == nil { - groupAllLn = len(allGroup.Peers) - 1 - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return updates - } - - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - var peerInSources, peerInDestinations bool - - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == newPeerID { - peerInSources = true - } else { - peerInSources = b.isPeerInGroups(rule.Sources, peerGroups) - } - - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == newPeerID { - peerInDestinations = true - } else { - peerInDestinations = b.isPeerInGroups(rule.Destinations, peerGroups) - } - - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - - if rule.Bidirectional { - if peerInSources { - if len(rule.Destinations) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) - } - if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.DestinationResource.ID, newPeerID, rule, FirewallRuleDirectionOUT) - } - } - if peerInDestinations { - if len(rule.Sources) > 0 { - b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) - } - if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { - b.addUpdateForDirectPeerResource(updates, rule.SourceResource.ID, newPeerID, rule, FirewallRuleDirectionIN) - } - } - } - } - } - - b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) - - b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) - - b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) - - return updates -} - -func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( - ctx context.Context, account *Account, newPeerID string, - updates map[string]*PeerUpdateDelta, -) { - resourceRouters := b.cache.resourceRouters - - for networkID, routers := range resourceRouters { - router, isRouter := routers[newPeerID] - if !isRouter || !router.Enabled { - continue - } - - for _, resource := range b.cache.globalResources { - if resource.NetworkID != networkID { - continue - } - - policies := b.cache.resourcePolicies[resource.ID] - if len(policies) == 0 { - continue - } - - peersWithAccess := make(map[string]struct{}) - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - groupPeers := b.cache.groupToPeers[sourceGroup] - for _, peerID := range groupPeers { - if peerID == newPeerID { - continue - } - - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { - peersWithAccess[peerID] = struct{}{} - } - } - } - } - - for peerID := range peersWithAccess { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - } - updates[peerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } - } -} - -func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( - newPeerID string, newPeer *nbpeer.Peer, - peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - processedPeerRoutes := make(map[string]map[route.ID]struct{}) - - for routeID, info := range b.cache.noACGRoutes { - if info.PeerID == newPeerID { - continue - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - - for _, acg := range peerGroups { - routeInfos := b.cache.acgToRoutes[acg] - if routeInfos == nil { - continue - } - - for routeID, info := range routeInfos { - if info.PeerID == newPeerID { - continue - } - - if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { - if _, processed := processedRoutes[routeID]; processed { - continue - } - } - - b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) - - if processedPeerRoutes[info.PeerID] == nil { - processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) - } - processedPeerRoutes[info.PeerID][routeID] = struct{}{} - } - } -} - -func (b *NetworkMapBuilder) addRouteFirewallUpdate( - updates map[string]*PeerUpdateDelta, peerID string, - routeID string, sourceIP string, -) { - delta := updates[peerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: peerID, - UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), - } - updates[peerID] = delta - } - - for _, existing := range delta.UpdateRouteFirewallRules { - if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { - return - } - } - - delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ - RuleID: routeID, - AddSourceIP: sourceIP, - }) -} - -func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( - ctx context.Context, account *Account, newPeerID string, - newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, -) { - for _, resource := range b.cache.globalResources { - resourcePolicies := b.cache.resourcePolicies - resourceRouters := b.cache.resourceRouters - - policies := resourcePolicies[resource.ID] - peerHasAccess := false - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - sourceGroups := policy.SourceGroups() - for _, sourceGroup := range sourceGroups { - if slices.Contains(peerGroups, sourceGroup) { - if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { - peerHasAccess = true - break - } - } - } - - if peerHasAccess { - break - } - } - - if !peerHasAccess { - continue - } - - networkRouters := resourceRouters[resource.NetworkID] - for routerPeerID, router := range networkRouters { - if !router.Enabled || routerPeerID == newPeerID { - continue - } - - delta := updates[routerPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: routerPeerID, - } - updates[routerPeerID] = delta - } - - if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - delta.RebuildRoutesView = true - } - } -} - -type PeerUpdateDelta struct { - PeerID string - AddConnectedPeers []string - AddFirewallRules []*FirewallRuleDelta - AddRoutes []route.ID - UpdateRouteFirewallRules []*RouteFirewallRuleUpdate - UpdateDNS bool - RebuildRoutesView bool -} - -func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) { - for _, peerID := range other.AddConnectedPeers { - if !slices.Contains(d.AddConnectedPeers, peerID) { - d.AddConnectedPeers = append(d.AddConnectedPeers, peerID) - } - } - - existingRuleIDs := make(map[string]struct{}, len(d.AddFirewallRules)) - for _, rule := range d.AddFirewallRules { - existingRuleIDs[rule.RuleID] = struct{}{} - } - for _, rule := range other.AddFirewallRules { - if _, exists := existingRuleIDs[rule.RuleID]; !exists { - d.AddFirewallRules = append(d.AddFirewallRules, rule) - existingRuleIDs[rule.RuleID] = struct{}{} - } - } - - for _, routeID := range other.AddRoutes { - if !slices.Contains(d.AddRoutes, routeID) { - d.AddRoutes = append(d.AddRoutes, routeID) - } - } - - existingRouteUpdates := make(map[string]map[string]struct{}) - for _, update := range d.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - for _, update := range other.UpdateRouteFirewallRules { - if existingRouteUpdates[update.RuleID] == nil { - existingRouteUpdates[update.RuleID] = make(map[string]struct{}) - } - if _, exists := existingRouteUpdates[update.RuleID][update.AddSourceIP]; !exists { - d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, update) - existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{} - } - } - - if other.UpdateDNS { - d.UpdateDNS = true - } - if other.RebuildRoutesView { - d.RebuildRoutesView = true - } -} - -type FirewallRuleDelta struct { - Rule *FirewallRule - RuleID string - Direction int -} - -type RouteFirewallRuleUpdate struct { - RuleID string - AddSourceIP string -} - -func (b *NetworkMapBuilder) addUpdateForPeersInGroups( - updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, - rule *PolicyRule, direction int, allGroupLn int, -) { - for _, groupID := range groupIDs { - peers := b.cache.groupToPeers[groupID] - cnt := 0 - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - cnt++ - } - all := false - if allGroupLn > 0 && cnt == allGroupLn { - all = true - } - newPeer := b.cache.globalPeers[newPeerID] - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - for _, peerID := range peers { - if peerID == newPeerID { - continue - } - if _, ok := b.validatedPeers[peerID]; !ok { - continue - } - targetPeer := b.cache.globalPeers[peerID] - if targetPeer == nil { - continue - } - - peerIPForRule := fr.PeerIP - if all { - peerIPForRule = allPeers - } - - b.addOrUpdateFirewallRuleInDelta(updates, peerID, newPeerID, rule, direction, fr, peerIPForRule, targetPeer) - } - } -} - -func (b *NetworkMapBuilder) addUpdateForDirectPeerResource( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, -) { - if targetPeerID == newPeerID { - return - } - - if _, ok := b.validatedPeers[targetPeerID]; !ok { - return - } - - newPeer := b.cache.globalPeers[newPeerID] - if newPeer == nil { - return - } - - targetPeer := b.cache.globalPeers[targetPeerID] - if targetPeer == nil { - return - } - - fr := &FirewallRule{ - PolicyID: rule.ID, - PeerIP: newPeer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: firewallRuleProtocol(rule.Protocol), - } - - b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer) -} - -func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta( - updates map[string]*PeerUpdateDelta, targetPeerID string, newPeerID string, - rule *PolicyRule, direction int, baseRule *FirewallRule, peerIP string, targetPeer *nbpeer.Peer, -) { - delta := updates[targetPeerID] - if delta == nil { - delta = &PeerUpdateDelta{ - PeerID: targetPeerID, - AddConnectedPeers: []string{newPeerID}, - AddFirewallRules: make([]*FirewallRuleDelta, 0), - } - updates[targetPeerID] = delta - } else if !slices.Contains(delta.AddConnectedPeers, newPeerID) { - delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID) - } - - baseRule.PeerIP = peerIP - - if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { - expandedRules := expandPortsAndRanges(*baseRule, rule, targetPeer) - for _, expandedRule := range expandedRules { - ruleID := b.generateFirewallRuleID(expandedRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: expandedRule, - RuleID: ruleID, - Direction: direction, - }) - } - } else { - ruleID := b.generateFirewallRuleID(baseRule) - delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ - Rule: baseRule, - RuleID: ruleID, - Direction: direction, - }) - } -} - -func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { - if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 { - if aclView := b.cache.peerACLs[peerID]; aclView != nil { - for _, connectedPeerID := range delta.AddConnectedPeers { - if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) { - aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID) - } - } - - for _, ruleDelta := range delta.AddFirewallRules { - b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule - - if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { - aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) - } - } - } - } - - if delta.RebuildRoutesView { - b.buildPeerRoutesView(account, peerID) - } else if len(delta.UpdateRouteFirewallRules) > 0 { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) - } - } - - if delta.UpdateDNS { - b.buildPeerDNSView(account, peerID) - } -} - -func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { - for _, update := range updates { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - rule := b.cache.globalRouteRules[ruleID] - if rule == nil { - continue - } - - if string(rule.RouteID) == update.RuleID { - if hasWildcard := slices.Contains(rule.SourceRanges, allWildcard) || slices.Contains(rule.SourceRanges, v6AllWildcard); hasWildcard { - break - } - - sourceIP := update.AddSourceIP - - if strings.Contains(sourceIP, ":") { - sourceIP += "/128" // IPv6 - } else { - sourceIP += "/32" // IPv4 - } - - if !slices.Contains(rule.SourceRanges, sourceIP) { - rule.SourceRanges = append(rule.SourceRanges, sourceIP) - } - break - } - } - } -} - -func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - - account := b.updateAccountLocked(acc) - - deletedPeer := b.cache.globalPeers[peerID] - if deletedPeer == nil { - return fmt.Errorf("peer %s not found in cache", peerID) - } - - deletedPeerKey := deletedPeer.Key - peerGroups := b.cache.peerToGroups[peerID] - peerIP := deletedPeer.IP.String() - - log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) - - delete(b.validatedPeers, peerID) - - routesToDelete := []route.ID{} - - for routeID, r := range account.Routes { - if r.Peer != deletedPeerKey && r.PeerID != peerID { - continue - } - if len(r.PeerGroups) == 0 { - routesToDelete = append(routesToDelete, routeID) - continue - } - newPeerAssigned := false - for _, groupID := range r.PeerGroups { - candidatePeerIDs := b.cache.groupToPeers[groupID] - for _, candidatePeerID := range candidatePeerIDs { - if candidatePeerID == peerID { - continue - } - if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { - r.Peer = candidatePeer.Key - r.PeerID = candidatePeerID - newPeerAssigned = true - break - } - } - if newPeerAssigned { - break - } - } - - if !newPeerAssigned { - routesToDelete = append(routesToDelete, routeID) - } - } - - for _, routeID := range routesToDelete { - delete(account.Routes, routeID) - } - - delete(b.cache.peerACLs, peerID) - delete(b.cache.peerRoutes, peerID) - delete(b.cache.peerDNS, peerID) - delete(b.cache.peerSSH, peerID) - - delete(b.cache.globalPeers, peerID) - - for acg, routeMap := range b.cache.acgToRoutes { - for routeID, info := range routeMap { - if info.PeerID == peerID { - delete(routeMap, routeID) - } - } - if len(routeMap) == 0 { - delete(b.cache.acgToRoutes, acg) - } - } - - for _, groupID := range peerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { - return id == peerID - }) - } - } - delete(b.cache.peerToGroups, peerID) - - affectedPeers := make(map[string]struct{}) - - for _, r := range account.Routes { - for _, groupID := range r.Groups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - - for _, groupID := range r.PeerGroups { - if peers := b.cache.groupToPeers[groupID]; peers != nil { - for _, p := range peers { - affectedPeers[p] = struct{}{} - } - } - } - } - - for affectedPeerID := range affectedPeers { - if affectedPeerID == peerID { - continue - } - b.buildPeerRoutesView(account, affectedPeerID) - } - - peersToRebuildACL := make(map[string]struct{}) - peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL) - for affectedPeerID, updates := range peerDeletionUpdates { - b.applyDeletionUpdates(affectedPeerID, updates) - } - - for affectedPeerID := range peersToRebuildACL { - b.buildPeerACLView(account, affectedPeerID) - } - - b.cleanupUnusedRules() - - log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) - - return nil -} - -func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( - deletedPeerID string, - peerIP string, - peerGroups []string, - peersToRebuildACL map[string]struct{}, -) map[string]*PeerDeletionUpdate { - - affected := make(map[string]*PeerDeletionUpdate) - - for peerID, aclView := range b.cache.peerACLs { - if peerID == deletedPeerID { - continue - } - - if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { - peersToRebuildACL[peerID] = struct{}{} - if affected[peerID] == nil { - affected[peerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - } - } - } - } - - affectedRouteOwners := make(map[string]struct{}) - - for _, groupID := range peerGroups { - if routeMap, ok := b.cache.acgToRoutes[groupID]; ok { - for _, info := range routeMap { - if info.PeerID != deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - } - } - - for _, info := range b.cache.noACGRoutes { - if info.PeerID != deletedPeerID { - affectedRouteOwners[info.PeerID] = struct{}{} - } - } - - for ownerPeerID := range affectedRouteOwners { - if affected[ownerPeerID] == nil { - affected[ownerPeerID] = &PeerDeletionUpdate{ - RemovePeerID: deletedPeerID, - PeerIP: peerIP, - RemoveFromSourceRanges: true, - } - } else { - affected[ownerPeerID].RemoveFromSourceRanges = true - } - } - - return affected -} - -type PeerDeletionUpdate struct { - RemovePeerID string - RemoveFirewallRuleIDs []string - RemoveRouteIDs []route.ID - RemoveFromSourceRanges bool - PeerIP string -} - -func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { - if routesView := b.cache.peerRoutes[peerID]; routesView != nil { - if len(updates.RemoveRouteIDs) > 0 { - routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { - return slices.Contains(updates.RemoveRouteIDs, routeID) - }) - } - - if updates.RemoveFromSourceRanges { - b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) - } - } -} - -func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { - sourceIPv4 := peerIP + "/32" - sourceIPv6 := peerIP + "/128" - - rulesToRemove := []string{} - - for _, ruleID := range routesView.RouteFirewallRuleIDs { - if rule := b.cache.globalRouteRules[ruleID]; rule != nil { - rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { - return source == sourceIPv4 || source == sourceIPv6 || source == peerIP - }) - - if len(rule.SourceRanges) == 0 { - rulesToRemove = append(rulesToRemove, ruleID) - } - } - } - - if len(rulesToRemove) > 0 { - routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { - return slices.Contains(rulesToRemove, ruleID) - }) - } -} - -func (b *NetworkMapBuilder) cleanupUnusedRules() { - usedFirewallRules := make(map[string]struct{}) - usedRouteRules := make(map[string]struct{}) - usedRoutes := make(map[route.ID]struct{}) - - for _, aclView := range b.cache.peerACLs { - for _, ruleID := range aclView.FirewallRuleIDs { - usedFirewallRules[ruleID] = struct{}{} - } - } - - for _, routesView := range b.cache.peerRoutes { - for _, ruleID := range routesView.RouteFirewallRuleIDs { - usedRouteRules[ruleID] = struct{}{} - } - - for _, routeID := range routesView.OwnRouteIDs { - usedRoutes[routeID] = struct{}{} - } - for _, routeID := range routesView.NetworkResourceIDs { - usedRoutes[routeID] = struct{}{} - } - } - - for ruleID := range b.cache.globalRules { - if _, used := usedFirewallRules[ruleID]; !used { - delete(b.cache.globalRules, ruleID) - } - } - - for ruleID := range b.cache.globalRouteRules { - if _, used := usedRouteRules[ruleID]; !used { - delete(b.cache.globalRouteRules, ruleID) - } - } - - for routeID := range b.cache.globalRoutes { - if _, used := usedRoutes[routeID]; !used { - delete(b.cache.globalRoutes, routeID) - } - } -} - -func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { - b.cache.mu.Lock() - defer b.cache.mu.Unlock() - peerStored, ok := b.cache.globalPeers[peer.ID] - if !ok { - return - } - *peerStored = *peer -} From f8745723fcb8edd19e8c07bcc12c06fb97226b1e Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 28 Apr 2026 12:42:19 +0300 Subject: [PATCH 08/11] [management] Add Microsoft AD FS support for embedded Dex identity providers (#6008) --- idp/dex/config.go | 4 +++- idp/dex/connector.go | 6 ++++-- management/server/identity_provider.go | 4 +++- management/server/types/identity_provider.go | 5 ++++- shared/management/http/api/openapi.yml | 1 + shared/management/http/api/types.gen.go | 3 +++ 6 files changed, 18 insertions(+), 5 deletions(-) diff --git a/idp/dex/config.go b/idp/dex/config.go index 7f5300f14..e686233ad 100644 --- a/idp/dex/config.go +++ b/idp/dex/config.go @@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) { // are stored with types that Dex can open. func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { switch connType { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": return "oidc", applyOIDCDefaults(connType, config) default: return connType, config @@ -218,6 +218,8 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) case "okta", "pocketid": augmented["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"} } return augmented diff --git a/idp/dex/connector.go b/idp/dex/connector.go index ba2bb1f00..8aba92999 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto var err error switch cfg.Type { - case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak": + case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": dexType = "oidc" configData, err = buildOIDCConnectorConfig(cfg, redirectURI) case "google": @@ -220,6 +220,8 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} case "pocketid": oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} + case "adfs": + oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"} } return encodeConnectorConfig(oidcConfig) } @@ -283,7 +285,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa // inferOIDCProviderType infers the specific OIDC provider from connector ID func inferOIDCProviderType(connectorID string) string { connectorIDLower := strings.ToLower(connectorID) - for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} { + for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} { if strings.Contains(connectorIDLower, provider) { return provider } diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go index 8fd96c238..f965f36b8 100644 --- a/management/server/identity_provider.go +++ b/management/server/identity_provider.go @@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C } // generateIdentityProviderID generates a unique ID for an identity provider. -// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft), +// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs), // the ID is prefixed with the type name. Generic OIDC providers get no prefix. func generateIdentityProviderID(idpType types.IdentityProviderType) string { id := xid.New().String() @@ -296,6 +296,8 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string { return "authentik-" + id case types.IdentityProviderTypeKeycloak: return "keycloak-" + id + case types.IdentityProviderTypeADFS: + return "adfs-" + id default: // Generic OIDC - no prefix return id diff --git a/management/server/types/identity_provider.go b/management/server/types/identity_provider.go index c4498e4d4..0c1f9509c 100644 --- a/management/server/types/identity_provider.go +++ b/management/server/types/identity_provider.go @@ -39,6 +39,8 @@ const ( IdentityProviderTypeAuthentik IdentityProviderType = "authentik" // IdentityProviderTypeKeycloak is the Keycloak identity provider IdentityProviderTypeKeycloak IdentityProviderType = "keycloak" + // IdentityProviderTypeADFS is the Microsoft AD FS identity provider + IdentityProviderTypeADFS IdentityProviderType = "adfs" ) // IdentityProvider represents an identity provider configuration @@ -112,7 +114,8 @@ func (t IdentityProviderType) IsValid() bool { switch t { case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra, IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID, - IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak: + IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak, + IdentityProviderTypeADFS: return true } return false diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 0b855db67..b70f89499 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2917,6 +2917,7 @@ components: - okta - pocketid - microsoft + - adfs example: oidc IdentityProvider: type: object diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 0317b8183..d56cb9b74 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -518,6 +518,7 @@ const ( IdentityProviderTypeOkta IdentityProviderType = "okta" IdentityProviderTypePocketid IdentityProviderType = "pocketid" IdentityProviderTypeZitadel IdentityProviderType = "zitadel" + IdentityProviderTypeAdfs IdentityProviderType = "adfs" ) // Valid indicates whether the value is a known member of the IdentityProviderType enum. @@ -537,6 +538,8 @@ func (e IdentityProviderType) Valid() bool { return true case IdentityProviderTypeZitadel: return true + case IdentityProviderTypeAdfs: + return true default: return false } From 6f0eff3ba0696f26032fe4412e35def32e27eac6 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 28 Apr 2026 14:48:28 +0300 Subject: [PATCH 09/11] [management] Handle single-string JWT group claim from IdPs (#6014) --- shared/auth/jwt/extractor.go | 10 ++++++++-- shared/auth/jwt/extractor_test.go | 9 +++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/shared/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go index 5806d1f4d..d113f53d5 100644 --- a/shared/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -146,7 +146,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string userJWTGroups := make([]string, 0) if claim, ok := claims[claimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { + switch claimGroups := claim.(type) { + case string: + // Some IdPs emit a single group claim as a string instead of an array. + userJWTGroups = append(userJWTGroups, claimGroups) + case []any: for _, g := range claimGroups { if group, ok := g.(string); ok { userJWTGroups = append(userJWTGroups, group) @@ -154,9 +158,11 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) } } + default: + log.Debugf("JWT claim %q is not a string or string array (type: %T): %v", claimName, claim, claim) } } else { - log.Debugf("JWT claim %q is not a string array", claimName) + log.Debugf("JWT claim %q is missing", claimName) } return userJWTGroups diff --git a/shared/auth/jwt/extractor_test.go b/shared/auth/jwt/extractor_test.go index 45529770d..4f8fe0007 100644 --- a/shared/auth/jwt/extractor_test.go +++ b/shared/auth/jwt/extractor_test.go @@ -249,6 +249,15 @@ func TestClaimsExtractor_ToGroups(t *testing.T) { groupClaimName: "groups", expectedGroups: []string{}, }, + { + name: "extracts single group string from claim", + claims: jwt.MapClaims{ + "sub": "user-123", + "groups": "admin", + }, + groupClaimName: "groups", + expectedGroups: []string{"admin"}, + }, { name: "handles custom claim name", claims: jwt.MapClaims{ From 9c50819f20b466aa0a56a56e4677ce24a7b5222a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Apr 2026 15:04:41 +0200 Subject: [PATCH 10/11] Don't mark management disconnected on transient job stream errors (#6005) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The JOB stream is a separate channel from the SYNC stream. Server-side EOF or other transient errors on the JOB stream do not indicate that the management connection is unhealthy — the SYNC stream remains the authoritative state signal. Previously, a JOB stream EOF would call notifyDisconnected and the client would emit OnConnecting to the UI. The backoff retry would reconnect the JOB stream, but handleJobStream never calls notifyConnected on success, so the UI was stuck on "Connecting" until the next SYNC event or health check. Keep notifyDisconnected for codes.PermissionDenied since IsLoginRequired relies on managementError to detect expired auth. --- shared/management/client/grpc.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index e9bea7ffb..2a51a777d 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -252,21 +252,19 @@ func (c *GrpcClient) handleJobStream( c.notifyDisconnected(err) return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer case codes.Canceled: - log.Debugf("management connection context has been canceled, this usually indicates shutdown") + log.Debugf("job stream context has been canceled, this usually indicates shutdown") return err case codes.Unimplemented: log.Warn("Job feature is not supported by the current management server version. " + "Please update the management service to use this feature.") return nil default: - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) return err } } else { // non-gRPC error - c.notifyDisconnected(err) - log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + log.Warnf("job stream disconnected, will retry silently. Reason: %v", err) return err } } From 8fc4265995e81dd6ff12e438b8b4232c7bd377bc Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 28 Apr 2026 15:04:48 +0200 Subject: [PATCH 11/11] [relay] evict foreign client cache on disconnect (#6015) * [relay] evict foreign client cache on disconnect When a foreign relay's TCP connection drops, the manager's onServerDisconnected handler only triggered reconnect logic for the home server; the disconnected foreign entry stayed in the relayClients cache. Subsequent OpenConn calls reused the closed client until the 60-second cleanup tick evicted it, breaking peer connectivity through that relay for up to a minute. Evict the foreign entry from the cache on disconnect so the next OpenConn dials a fresh client. Also: - Make the reconnect backoff cap configurable via WithMaxBackoffInterval ManagerOption; the previous hard-coded 60s constant forced TestAutoReconnect to sleep ~61s. Test now polls Ready() and finishes in ~2s. - Add NB_HOME_RELAY_SERVERS env var that overrides the relay URL list received from management, so a peer can be pinned to a specific home relay (used by the netbird-conn-lab Edge 4 reproducer). * [client] treat empty NB_HOME_RELAY_SERVERS as unset Returning (urls=[], ok=true) when the env var contained only separators or whitespace caused callers to wipe the mgmt-provided relay list, leaving the peer with no relays. Treat a parsed-empty result the same as an unset env. --- client/internal/connect.go | 4 +++ client/internal/engine.go | 7 ++++- client/internal/peer/env.go | 28 +++++++++++++++++++- shared/relay/client/guard.go | 30 +++++++++++++--------- shared/relay/client/manager.go | 40 +++++++++++++++++++++++++---- shared/relay/client/manager_test.go | 23 +++++++++++++++-- 6 files changed, 111 insertions(+), 21 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index ac498f719..72e096a80 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -333,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.statusRecorder.MarkSignalConnected() relayURLs, token := parseRelayInfo(loginResp) + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + relayURLs = override + } peerConfig := loginResp.GetPeerConfig() engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) diff --git a/client/internal/engine.go b/client/internal/engine.go index 8d7e02bd5..351e4bfe9 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -944,7 +944,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { return fmt.Errorf("update relay token: %w", err) } - e.relayManager.UpdateServerURLs(update.Urls) + urls := update.Urls + if override, ok := peer.OverrideRelayURLs(); ok { + log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override) + urls = override + } + e.relayManager.UpdateServerURLs(urls) // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // We can ignore all errors because the guard will manage the reconnection retries. diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index b4ba9ad7b..ed6a3af53 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -7,7 +7,8 @@ import ( ) const ( - EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBForceRelay = "NB_FORCE_RELAY" + EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS" ) func IsForceRelayed() bool { @@ -16,3 +17,28 @@ func IsForceRelayed() bool { } return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") } + +// OverrideRelayURLs returns the relay server URL list set in +// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether +// the override is active. When the env var is unset, the boolean is false +// and the caller should keep the list received from the management server. +// Intended for lab/debug scenarios where a peer must pin to a specific home +// relay regardless of what management offers. +func OverrideRelayURLs() ([]string, bool) { + raw := os.Getenv(EnvKeyNBHomeRelayServers) + if raw == "" { + return nil, false + } + parts := strings.Split(raw, ",") + urls := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + urls = append(urls, p) + } + } + if len(urls) == 0 { + return nil, false + } + return urls, true +} diff --git a/shared/relay/client/guard.go b/shared/relay/client/guard.go index f4d3a8cce..d7892d0ce 100644 --- a/shared/relay/client/guard.go +++ b/shared/relay/client/guard.go @@ -8,10 +8,7 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - // TODO: make it configurable, the manager should validate all configurable parameters - reconnectingTimeout = 60 * time.Second -) +const defaultMaxBackoffInterval = 60 * time.Second // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { @@ -19,14 +16,23 @@ type Guard struct { OnNewRelayClient chan *Client OnReconnected chan struct{} serverPicker *ServerPicker + + // maxBackoffInterval caps the exponential backoff between reconnect + // attempts. + maxBackoffInterval time.Duration } -// NewGuard creates a new guard for the relay client. -func NewGuard(sp *ServerPicker) *Guard { +// NewGuard creates a new guard for the relay client. A non-positive +// maxBackoffInterval falls back to defaultMaxBackoffInterval. +func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard { + if maxBackoffInterval <= 0 { + maxBackoffInterval = defaultMaxBackoffInterval + } g := &Guard{ - OnNewRelayClient: make(chan *Client, 1), - OnReconnected: make(chan struct{}, 1), - serverPicker: sp, + OnNewRelayClient: make(chan *Client, 1), + OnReconnected: make(chan struct{}, 1), + serverPicker: sp, + maxBackoffInterval: maxBackoffInterval, } return g } @@ -49,7 +55,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { } // start a ticker to pick a new server - ticker := exponentTicker(ctx) + ticker := g.exponentTicker(ctx) defer ticker.Stop() for { @@ -125,11 +131,11 @@ func (g *Guard) notifyReconnected() { } } -func exponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) exponentTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 2 * time.Second, Multiplier: 2, - MaxInterval: reconnectingTimeout, + MaxInterval: g.maxBackoffInterval, Clock: backoff.SystemClock, }, ctx) diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 6220e7f6b..37104bfe7 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -39,6 +39,15 @@ func NewRelayTrack() *RelayTrack { type OnServerCloseListener func() +// ManagerOption configures a Manager at construction time. +type ManagerOption func(*Manager) + +// WithMaxBackoffInterval caps the exponential backoff between reconnect +// attempts to the home relay. A non-positive value keeps the default. +func WithMaxBackoffInterval(d time.Duration) ManagerOption { + return func(m *Manager) { m.maxBackoffInterval = d } +} + // Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL // and automatically reconnect to them in case disconnection. // The manager also manage temporary relay connection. If a client wants to communicate with a client on a @@ -64,12 +73,13 @@ type Manager struct { onReconnectedListenerFn func() listenerLock sync.Mutex - mtu uint16 + mtu uint16 + maxBackoffInterval time.Duration } // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { +func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager { tokenStore := &relayAuth.TokenStore{} m := &Manager{ @@ -86,8 +96,11 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), } + for _, opt := range opts { + opt(m) + } m.serverPicker.ServerURLs.Store(serverURLs) - m.reconnectGuard = NewGuard(m.serverPicker) + m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval) return m } @@ -290,19 +303,36 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } -// onServerDisconnected start to reconnection for home server only +// onServerDisconnected handles relay disconnect events. For the home server it +// starts the reconnect guard. For foreign servers it evicts the now-dead client +// from the cache so the next OpenConn builds a fresh one instead of reusing a +// closed client. func (m *Manager) onServerDisconnected(serverAddress string) { m.relayClientMu.Lock() - if serverAddress == m.relayClient.connectionURL { + isHome := m.relayClient != nil && serverAddress == m.relayClient.connectionURL + if isHome { go func(client *Client) { m.reconnectGuard.StartReconnectTrys(m.ctx, client) }(m.relayClient) } m.relayClientMu.Unlock() + if !isHome { + m.evictForeignRelay(serverAddress) + } + m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) evictForeignRelay(serverAddress string) { + m.relayClientsMutex.Lock() + defer m.relayClientsMutex.Unlock() + if _, ok := m.relayClients[serverAddress]; ok { + delete(m.relayClients, serverAddress) + log.Debugf("evicted disconnected foreign relay client: %s", serverAddress) + } +} + func (m *Manager) listenGuardEvent(ctx context.Context) { for { select { diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index fb91f7682..5bbcad886 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "fmt" "testing" "time" @@ -360,7 +361,8 @@ func TestAutoReconnect(t *testing.T) { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU, + WithMaxBackoffInterval(2*time.Second)) err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -384,7 +386,9 @@ func TestAutoReconnect(t *testing.T) { } log.Infof("waiting for reconnection") - time.Sleep(reconnectingTimeout + 1*time.Second) + if err := waitForReady(ctx, clientAlice, 15*time.Second); err != nil { + t.Fatalf("manager did not reconnect: %s", err) + } log.Infof("reopent the connection") _, err = clientAlice.OpenConn(ctx, ra, "bob") @@ -393,6 +397,21 @@ func TestAutoReconnect(t *testing.T) { } } +func waitForReady(ctx context.Context, m *Manager, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if m.Ready() { + return nil + } + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + return fmt.Errorf("manager not ready within %s", timeout) +} + func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background()