From ab2a8794e7a41693fff303725c96399ca190e8ff Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 14 May 2026 12:30:42 +0200 Subject: [PATCH 01/31] [client] Add short flags for status command options (#6137) * [client] Add short flags for status command options * uppercase filters --- client/cmd/status.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/client/cmd/status.go b/client/cmd/status.go index dae30e854..103b3044a 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -43,16 +43,16 @@ func init() { ipsFilterMap = make(map[string]struct{}) prefixNamesFilterMap = make(map[string]struct{}) statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format") - statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") - statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") - statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") - statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer") + statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format") + statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format") + statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") + statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6") - statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") - statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") - statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") - statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") - statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") + statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1") + statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") + statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") + statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") + statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)") } func statusFunc(cmd *cobra.Command, args []string) error { From 77b479286e399660ef2bdcbe7983363946660574 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 14 May 2026 13:27:50 +0200 Subject: [PATCH 02/31] [management] fix offline statuses for public proxy clusters (#6133) --- .../reverseproxy/domain/manager/manager.go | 23 ++++++- .../domain/manager/manager_test.go | 60 ++++++++++++++++--- .../reverseproxy/service/manager/manager.go | 59 +++++++++--------- management/server/store/sql_store.go | 28 +++++++-- management/server/store/sql_store_test.go | 52 ++++++++++++++++ 5 files changed, 177 insertions(+), 45 deletions(-) diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index ab899e0bf..2790b5f20 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -304,10 +304,27 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s if err != nil { return nil, fmt.Errorf("get BYOP cluster addresses: %w", err) } - if len(byopAddresses) > 0 { - return byopAddresses, nil + publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("get public cluster addresses: %w", err) } - return m.proxyManager.GetActiveClusterAddresses(ctx) + seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses)) + merged := make([]string, 0, len(byopAddresses)+len(publicAddresses)) + for _, addr := range byopAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + for _, addr := range publicAddresses { + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + merged = append(merged, addr) + } + return merged, nil } func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go index fdeb0765f..5e7bbfc36 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -40,22 +40,37 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) return nil } -func TestGetClusterAllowList_BYOPProxy(t *testing.T) { +func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { assert.Equal(t, "acc-123", accID) return []string{"byop.example.com"}, nil }, getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist") - return nil, nil + return []string{"eu.proxy.netbird.io"}, nil }, } mgr := Manager{proxyManager: pm} result, err := mgr.getClusterAllowList(context.Background(), "acc-123") require.NoError(t, err) - assert.Equal(t, []string{"byop.example.com"}, result) + assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result) +} + +func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"shared.example.com", "byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result) } func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) { @@ -79,10 +94,6 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { return nil, errors.New("db error") }, - getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { - t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails") - return nil, nil - }, } mgr := Manager{proxyManager: pm} @@ -92,6 +103,23 @@ func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "BYOP cluster addresses") } +func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, errors.New("db error") + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "public cluster addresses") +} + func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { @@ -108,3 +136,19 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) { assert.Equal(t, []string{"eu.proxy.netbird.io"}, result) } +func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) { + pm := &mockProxyManager{ + getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) { + return []string{"byop.example.com"}, nil + }, + getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) { + return nil, nil + }, + } + + mgr := Manager{proxyManager: pm} + result, err := mgr.getClusterAllowList(context.Background(), "acc-123") + require.NoError(t, err) + assert.Equal(t, []string{"byop.example.com"}, result) +} + diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index c866d8f75..4a8598afb 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -306,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if svc.Domain != "" { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil { @@ -321,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc * return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -435,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { customPorts := m.clusterCustomPorts(ctx, svc) + if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil { + return err + } + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil { return err @@ -448,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee return err } - if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { - return err - } - if err := transaction.CreateService(ctx, svc); err != nil { return fmt.Errorf("create service: %w", err) } @@ -552,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se svcForCaps.ProxyCluster = effectiveCluster customPorts := m.clusterCustomPorts(ctx, &svcForCaps) + if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil { + return nil, err + } + + // Validate subdomain requirement *before* the transaction: the underlying + // capability lookup talks to the main DB pool, and SQLite's single-connection + // pool would self-deadlock if this ran while the tx already held the only + // connection. + if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil { + return nil, err + } + var updateInfo serviceUpdateInfo err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts) + return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster) }) return &updateInfo, err @@ -585,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, return existing.ProxyCluster, nil } -func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error { +func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error { existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID) if err != nil { return err @@ -603,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St updateInfo.domainChanged = existingService.Domain != service.Domain if updateInfo.domainChanged { - if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil { + if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil { return err } } else { service.ProxyCluster = existingService.ProxyCluster } - if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { - return err - } - m.preserveExistingAuthSecrets(service, existingService) if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil { return err @@ -628,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St if err := m.checkPortConflict(ctx, transaction, service); err != nil { return err } - if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil { - return err - } if err := transaction.UpdateService(ctx, service); err != nil { return fmt.Errorf("update service: %w", err) } @@ -638,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St return nil } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error { +// handleDomainChange validates the new domain is free inside the transaction +// and applies the pre-resolved cluster (computed outside the tx by +// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks +// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1) +// because the transaction already holds the only connection. +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error { if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil { return err } - - if m.clusterDeriver != nil { - newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain) - if err != nil { - log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain) - } else { - svc.ProxyCluster = newCluster - } + if effectiveCluster != "" { + svc.ProxyCluster = effectiveCluster } - return nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4c2f0be52..893ee2168 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "net/url" "os" "path/filepath" "runtime" @@ -2794,12 +2795,27 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe connStr = filepath.Join(dataDir, filePath) } - // Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows - if hasQuery { - connStr += "?" + query - } else if runtime.GOOS != "windows" { + // Compose query parameters. User-provided ?_busy_timeout (or its mattn alias + // ?_timeout) overrides our default; otherwise inject 30s so SQLite waits at + // most that long on a lock instead of blocking the only Go-side connection. + // mattn/go-sqlite3 applies PRAGMA from the DSN on every fresh connection, so + // the value survives ConnMaxIdleTime/ConnMaxLifetime recycling. cache=shared + // stays the default on non-Windows for the same reason as before. + parsed, _ := url.ParseQuery(query) + var defaults []string + if parsed.Get("_busy_timeout") == "" && parsed.Get("_timeout") == "" { + defaults = append(defaults, "_busy_timeout=30000") + } + if !hasQuery && runtime.GOOS != "windows" { // To avoid `The process cannot access the file because it is being used by another process` on Windows - connStr += "?cache=shared" + defaults = append(defaults, "cache=shared") + } + parts := defaults + if hasQuery { + parts = append(parts, query) + } + if len(parts) > 0 { + connStr += "?" + strings.Join(parts, "&") } db, err := gorm.Open(sqlite.Open(connStr), getGormConfig()) @@ -3402,7 +3418,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { - timeoutCtx, cancel := context.WithTimeout(context.Background(), s.transactionTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, s.transactionTimeout) defer cancel() startTime := time.Now() diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2819265c3..7515add62 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -4592,3 +4592,55 @@ func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) { require.NoError(t, err) assert.Equal(t, 0, len(remainingRecords)) } + +// TestNewSqliteStore_BusyTimeoutApplied opens a fresh SQLite store and verifies +// that the _busy_timeout DSN parameter took effect at the driver level. Without +// this, lock contention on the single SQLite connection waits indefinitely on +// the Go side and can be hidden behind the 5-minute transactionTimeout. +func TestNewSqliteStore_BusyTimeoutApplied(t *testing.T) { + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, 30000, busyTimeout, "SQLite busy_timeout must be set via DSN so it survives connection recycling") +} + +// TestNewSqliteStore_BusyTimeoutRespectsUserOverride confirms that an operator +// passing _busy_timeout or its mattn alias _timeout via NB_STORE_ENGINE_SQLITE_FILE +// wins over our 30s default. This guards the DSN merge logic in NewSqliteStore. +func TestNewSqliteStore_BusyTimeoutRespectsUserOverride(t *testing.T) { + cases := []struct { + name string + envFile string + expected int + }{ + {name: "explicit _busy_timeout wins", envFile: "store.db?_busy_timeout=5000", expected: 5000}, + {name: "alias _timeout wins", envFile: "store.db?_timeout=7000", expected: 7000}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", tc.envFile) + dir := t.TempDir() + store, err := NewSqliteStore(context.Background(), dir, nil, true) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close(context.Background()) + }) + + sqlDB, err := store.db.DB() + require.NoError(t, err) + row := sqlDB.QueryRow("PRAGMA busy_timeout") + var busyTimeout int + require.NoError(t, row.Scan(&busyTimeout)) + assert.Equal(t, tc.expected, busyTimeout) + }) + } +} From ea9fab4396fc5513f7c62e3465dd361dc8bb9e91 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:05:33 +0900 Subject: [PATCH 03/31] [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) --- management/server/account.go | 12 +++++++++ management/server/peer.go | 17 +++++++----- management/server/types/account.go | 30 ++++++++++----------- management/server/types/group.go | 5 +++- management/server/types/ipv6_groups_test.go | 30 +++++++++++++++++++++ 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 364c0c37b..77a46a069 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2487,6 +2487,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran allowedPeers[peerID] = struct{}{} } } + + // Embedded proxy peers sit outside regular group membership but must + // participate in any v6-enabled overlay to reach v6-only peers. + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + if err != nil { + return nil, fmt.Errorf("get peers: %w", err) + } + for _, p := range peers { + if p.ProxyMeta.Embedded { + allowedPeers[p.ID] = struct{}{} + } + } return allowedPeers, nil } diff --git a/management/server/peer.go b/management/server/peer.go index 8a39fbbb8..c3b130ba2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -762,16 +762,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe newPeer.IP = freeIP if len(settings.IPv6EnabledGroups) > 0 && network.NetV6.IP != nil { - var allGroupID string - if !peer.ProxyMeta.Embedded { - allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, "All") - if err != nil { - log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) - } else { + // Embedded proxy peers are not group members but participate in any + // IPv6-enabled overlay so reverse-proxy traffic reaches v6-only peers. + allocate := peer.ProxyMeta.Embedded + if !allocate { + var allGroupID string + if allGroup, err := am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, types.GroupAllName); err == nil { allGroupID = allGroup.ID + } else { + log.WithContext(ctx).Debugf("get All group for IPv6 allocation: %v", err) } + allocate = peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) } - if peerWillHaveIPv6(settings, peerAddConfig.GroupsToAdd, allGroupID) { + if allocate { v6Prefix, err := netip.ParsePrefix(network.NetV6.String()) if err != nil { return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err) diff --git a/management/server/types/account.go b/management/server/types/account.go index 49600163a..870333a60 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -598,28 +598,21 @@ func (a *Account) GetPeerGroups(peerID string) LookupMap { return groupList } -// PeerIPv6Allowed reports whether the given peer is in any of the account's IPv6 enabled groups. +// PeerIPv6Allowed reports whether the given peer participates in the IPv6 overlay. // Returns false if IPv6 is disabled or no groups are configured. func (a *Account) PeerIPv6Allowed(peerID string) bool { - if len(a.Settings.IPv6EnabledGroups) == 0 { - return false - } - - for _, groupID := range a.Settings.IPv6EnabledGroups { - group, ok := a.Groups[groupID] - if !ok { - continue - } - if slices.Contains(group.Peers, peerID) { - return true - } - } - return false + _, ok := a.peerIPv6AllowedSet()[peerID] + return ok } -// peerIPv6AllowedSet returns a set of peer IDs that belong to any IPv6-enabled group. +// peerIPv6AllowedSet returns the set of peer IDs that participate in the IPv6 overlay: +// members of any IPv6-enabled group, plus every embedded proxy peer (which sit outside +// regular group membership but must reach v6-enabled peers). func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result := make(map[string]struct{}) + if len(a.Settings.IPv6EnabledGroups) == 0 { + return result + } for _, groupID := range a.Settings.IPv6EnabledGroups { group, ok := a.Groups[groupID] if !ok { @@ -629,6 +622,11 @@ func (a *Account) peerIPv6AllowedSet() map[string]struct{} { result[peerID] = struct{}{} } } + for id, p := range a.Peers { + if p != nil && p.ProxyMeta.Embedded { + result[id] = struct{}{} + } + } return result } diff --git a/management/server/types/group.go b/management/server/types/group.go index 00fdf7a69..b4f50080a 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -92,9 +92,12 @@ func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } +// GroupAllName is the reserved name of the default group that contains every peer in an account. +const GroupAllName = "All" + // IsGroupAll checks if the group is a default "All" group. func (g *Group) IsGroupAll() bool { - return g.Name == "All" + return g.Name == GroupAllName } // AddPeer adds peerID to Peers if not present, returning true if added. diff --git a/management/server/types/ipv6_groups_test.go b/management/server/types/ipv6_groups_test.go index 5151e1b1f..766a9c92c 100644 --- a/management/server/types/ipv6_groups_test.go +++ b/management/server/types/ipv6_groups_test.go @@ -232,3 +232,33 @@ func TestIPv6RecalculationOnGroupChange(t *testing.T) { assert.True(t, account.PeerIPv6Allowed("peer3"), "peer3 now in infra") }) } + +func TestPeerIPv6AllowedEmbeddedProxy(t *testing.T) { + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peer1": {ID: "peer1"}, + "proxy": {ID: "proxy", ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "netbird.test"}}, + }, + Groups: map[string]*Group{ + "group-devs": {ID: "group-devs", Peers: []string{"peer1"}}, + }, + Settings: &Settings{}, + } + + t.Run("embedded proxy allowed when any v6 group exists, without group membership", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + assert.True(t, account.PeerIPv6Allowed("proxy"), "embedded proxy participates in v6 overlay") + assert.True(t, account.PeerIPv6Allowed("peer1"), "regular peer in enabled group still allowed") + }) + + t.Run("embedded proxy denied when no v6 group enabled", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = nil + assert.False(t, account.PeerIPv6Allowed("proxy"), "v6 disabled account-wide denies embedded proxies too") + }) + + t.Run("non-embedded peer outside any enabled group is not pulled in", func(t *testing.T) { + account.Settings.IPv6EnabledGroups = []string{"group-devs"} + account.Peers["lonely"] = &nbpeer.Peer{ID: "lonely"} + assert.False(t, account.PeerIPv6Allowed("lonely"), "embedded-proxy bypass must not leak to regular peers") + }) +} From 3f914090cbb345707a88b5edb20d9c1351873b4c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:22:53 +0900 Subject: [PATCH 04/31] [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) --- client/embed/embed.go | 4 +- client/internal/debug/debug.go | 36 +++-- client/internal/debug/debug_linux.go | 195 +++++++++++++++++------- client/internal/debug/debug_nonlinux.go | 5 + 4 files changed, 178 insertions(+), 62 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index 4b9445b97..8b669e547 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr) if err != nil { @@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) { if err != nil { return nil, fmt.Errorf("split host port: %w", err) } - listenAddr := fmt.Sprintf("%s:%s", addr, port) + listenAddr := net.JoinHostPort(addr.String(), port) udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 9c50f02b3..ebaf71b21 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. -iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. -nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided. +ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided. +ipset.txt: Anonymized ipset list output, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided. +sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only). resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. @@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. Firewall Rules (Linux only) -The bundle includes two separate firewall rule files: +The bundle includes the following firewall-related files: iptables.txt: -- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L' - Includes all tables (filter, nat, mangle, raw, security) - Shows packet and byte counters for each rule - All IP addresses are anonymized - Chain names, table names, and other non-sensitive information remain unchanged +ip6tables.txt: +- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L' +- Same table coverage and anonymization as iptables.txt +- Omitted when ip6tables is not installed or no IPv6 rules are present + +ipset.txt: +- Output of 'ipset list' (family-agnostic) +- IP addresses are anonymized; set names and types remain unchanged + nftables.txt: -- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset' - Includes rule handle numbers and packet counters -- All tables, chains, and rules are included -- Shows packet and byte counters for each rule -- All IP addresses are anonymized -- Chain names, table names, and other non-sensitive information remain unchanged +- All IP addresses are anonymized; chain/table names remain unchanged + +sysctls.txt: +- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify +- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf +- Interface names anonymized when --anonymize is set IP Rules (Linux only) The ip_rules.txt file contains detailed IP routing rule information: @@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + if err := g.addSysctls(); err != nil { + log.Errorf("failed to add sysctls to debug bundle: %v", err) + } + if err := g.addDNSInfo(); err != nil { log.Errorf("failed to add DNS info to debug bundle: %v", err) } diff --git a/client/internal/debug/debug_linux.go b/client/internal/debug/debug_linux.go index aedf88b79..40d864eda 100644 --- a/client/internal/debug/debug_linux.go +++ b/client/internal/debug/debug_linux.go @@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) { // addFirewallRules collects and adds firewall rules to the archive func (g *BundleGenerator) addFirewallRules() error { log.Info("Collecting firewall rules") - iptablesRules, err := collectIPTablesRules() + g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt") + g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt") + + ipsetOutput, err := collectIPSets() if err != nil { - log.Warnf("Failed to collect iptables rules: %v", err) + log.Warnf("Failed to collect ipset information: %v", err) } else { if g.anonymize { - iptablesRules = g.anonymizer.AnonymizeString(iptablesRules) + ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput) } - if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil { - log.Warnf("Failed to add iptables rules to bundle: %v", err) + if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil { + log.Warnf("Failed to add ipset output to bundle: %v", err) } } @@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error { return nil } -// collectIPTablesRules collects rules using both iptables-save and verbose listing -func collectIPTablesRules() (string, error) { - var builder strings.Builder - - saveOutput, err := collectIPTablesSave() +// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle. +func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) { + rules, err := collectIPTablesRules(saveBin, listBin) if err != nil { - log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) - } else { - builder.WriteString("=== iptables-save output ===\n") + log.Warnf("Failed to collect %s rules: %v", listBin, err) + return + } + if g.anonymize { + rules = g.anonymizer.AnonymizeString(rules) + } + if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil { + log.Warnf("Failed to add %s rules to bundle: %v", listBin, err) + } +} + +// collectIPTablesRules collects rules using both and verbose listing via . +// Returns an error when neither command produced any output (e.g. the binary is missing), +// so the caller can skip writing an empty file. +func collectIPTablesRules(saveBin, listBin string) (string, error) { + var builder strings.Builder + var collected bool + var firstErr error + + saveOutput, err := runCommand(saveBin) + switch { + case err != nil: + firstErr = err + log.Warnf("Failed to collect %s output: %v", saveBin, err) + case strings.TrimSpace(saveOutput) == "": + log.Debugf("%s produced no output, skipping", saveBin) + default: + builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin)) builder.WriteString(saveOutput) builder.WriteString("\n") + collected = true } - ipsetOutput, err := collectIPSets() - if err != nil { - log.Warnf("Failed to collect ipset information: %v", err) - } else { - builder.WriteString("=== ipset list output ===\n") - builder.WriteString(ipsetOutput) - builder.WriteString("\n") - } - - builder.WriteString("=== iptables -v -n -L output ===\n") + listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin) + builder.WriteString(listHeader) tables := []string{"filter", "nat", "mangle", "raw", "security"} - for _, table := range tables { - builder.WriteString(fmt.Sprintf("*%s\n", table)) - - stats, err := getTableStatistics(table) + stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table) if err != nil { - log.Warnf("Failed to get statistics for table %s: %v", table, err) + if firstErr == nil { + firstErr = err + } + log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err) continue } + builder.WriteString(fmt.Sprintf("*%s\n", table)) builder.WriteString(stats) builder.WriteString("\n") + collected = true } + if !collected { + return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr) + } return builder.String(), nil } @@ -214,34 +238,15 @@ func collectIPSets() (string, error) { return ipsets, nil } -// collectIPTablesSave uses iptables-save to get rule definitions -func collectIPTablesSave() (string, error) { - cmd := exec.Command("iptables-save") +// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure. +func runCommand(name string, args ...string) (string, error) { + cmd := exec.Command(name, args...) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) - } - - rules := stdout.String() - if strings.TrimSpace(rules) == "" { - return "", fmt.Errorf("no iptables rules found") - } - - return rules, nil -} - -// getTableStatistics gets verbose statistics for an entire table using iptables command -func getTableStatistics(table string) (string, error) { - cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String()) } return stdout.String(), nil @@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string { return fmt.Sprintf("type-%v", keyType) } } + +// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle. +func (g *BundleGenerator) addSysctls() error { + log.Info("Collecting sysctls") + content := collectSysctls() + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil { + return fmt.Errorf("add sysctls to bundle: %w", err) + } + return nil +} + +// collectSysctls reads every sysctl that the netbird client may modify, plus +// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic. +// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf. +func collectSysctls() string { + var builder strings.Builder + + writeSysctlGroup(&builder, "forwarding", []string{ + "net.ipv4.ip_forward", + "net.ipv6.conf.all.forwarding", + "net.ipv6.conf.default.forwarding", + }) + writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding")) + writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding")) + writeSysctlGroup(&builder, "rp_filter", append( + []string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"}, + listInterfaceSysctls("ipv4", "rp_filter")..., + )) + writeSysctlGroup(&builder, "src_valid_mark", append( + []string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"}, + listInterfaceSysctls("ipv4", "src_valid_mark")..., + )) + writeSysctlGroup(&builder, "conntrack", []string{ + "net.netfilter.nf_conntrack_acct", + "net.netfilter.nf_conntrack_tcp_loose", + }) + writeSysctlGroup(&builder, "tcp", []string{ + "net.ipv4.tcp_tw_reuse", + }) + + return builder.String() +} + +func writeSysctlGroup(builder *strings.Builder, title string, keys []string) { + builder.WriteString(fmt.Sprintf("=== %s ===\n", title)) + for _, key := range keys { + value, err := readSysctl(key) + if err != nil { + builder.WriteString(fmt.Sprintf("%s = \n", key, err)) + continue + } + builder.WriteString(fmt.Sprintf("%s = %s\n", key, value)) + } + builder.WriteString("\n") +} + +// listInterfaceSysctls returns net.ipvX.conf.. keys for every +// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default" +// (callers add those explicitly so they appear first). +func listInterfaceSysctls(family, leaf string) []string { + dir := fmt.Sprintf("/proc/sys/net/%s/conf", family) + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + var keys []string + for _, e := range entries { + name := e.Name() + if name == "all" || name == "default" { + continue + } + keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf)) + } + sort.Strings(keys) + return keys +} + +func readSysctl(key string) (string, error) { + path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/")) + value, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(value)), nil +} diff --git a/client/internal/debug/debug_nonlinux.go b/client/internal/debug/debug_nonlinux.go index ace53bd94..878fee40f 100644 --- a/client/internal/debug/debug_nonlinux.go +++ b/client/internal/debug/debug_nonlinux.go @@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error { // IP rules are only supported on Linux return nil } + +func (g *BundleGenerator) addSysctls() error { + // Sysctl collection is only supported on Linux + return nil +} From 07e5450117dd0451aaeefc18729a822115587e69 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 May 2026 23:42:40 +0900 Subject: [PATCH 05/31] [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) --- .../modules/reverseproxy/service/service.go | 18 ++++- .../reverseproxy/service/service_test.go | 77 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 769e037bc..166a66a5f 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { } // HTTP/HTTPS: build full URL + hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]") targetURL := url.URL{ Scheme: target.Protocol, - Host: target.Host, + Host: bracketIPv6Host(hostNoBrackets), Path: "/", } if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { - targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10)) + targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10)) } path := "/" @@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping { return pathMappings } +// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as +// required for the Host field of net/url.URL (RFC 3986 ยง3.2.2). v4-mapped IPv6 +// addresses are bracketed too since their textual form contains colons. +func bracketIPv6Host(host string) string { + if strings.HasPrefix(host, "[") { + return host + } + if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() { + return "[" + host + "]" + } + return host +} + func operationToProtoType(op Operation) proto.ProxyMappingUpdateType { switch op { case Create: diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index ff54cb79f..f1349ff65 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) { port: 80, wantTarget: "https://10.0.0.1:80/", }, + { + name: "domain host without port is unchanged", + protocol: "http", + host: "example.com", + port: 0, + wantTarget: "http://example.com/", + }, + { + name: "domain host with non-default port is unchanged", + protocol: "http", + host: "example.com", + port: 8080, + wantTarget: "http://example.com:8080/", + }, + { + name: "ipv6 host without port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 0, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with default port omits port and brackets host", + protocol: "http", + host: "fb00:cafe:1::3", + port: 80, + wantTarget: "http://[fb00:cafe:1::3]/", + }, + { + name: "ipv6 host with non-default port is bracketed", + protocol: "http", + host: "fb00:cafe:1::3", + port: 8080, + wantTarget: "http://[fb00:cafe:1::3]:8080/", + }, + { + name: "ipv6 loopback without port is bracketed", + protocol: "http", + host: "::1", + port: 0, + wantTarget: "http://[::1]/", + }, + { + name: "ipv6 host with 5-digit port is bracketed", + protocol: "http", + host: "fb00:cafe::1", + port: 18080, + wantTarget: "http://[fb00:cafe::1]:18080/", + }, + { + name: "pre-bracketed ipv6 without port stays single-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", + port: 0, + wantTarget: "http://[fb00:cafe::1]/", + }, + { + name: "pre-bracketed ipv6 with port is not double-bracketed", + protocol: "http", + host: "[fb00:cafe::1]", + port: 8080, + wantTarget: "http://[fb00:cafe::1]:8080/", + }, + { + name: "v4-mapped ipv6 host without port is bracketed", + protocol: "http", + host: "::ffff:10.0.0.1", + port: 0, + wantTarget: "http://[::ffff:10.0.0.1]/", + }, + { + name: "full-form 8-group ipv6 without port is bracketed", + protocol: "http", + host: "fb00:cafe:1:0:0:0:0:3", + port: 0, + wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 2ccae7ec479c106efb6d7a7edff4bb55affb2aa4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 May 2026 23:58:47 +0900 Subject: [PATCH 06/31] [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) --- client/internal/routemanager/manager.go | 5 +- .../internal/routeselector/routeselector.go | 84 ++++++----- .../routeselector/routeselector_test.go | 131 ++++++++++++++++++ client/ui/network.go | 10 +- 4 files changed, 197 insertions(+), 33 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e5d9363ca..907f1f592 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -704,7 +704,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI } func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool { - return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR + if len(routes) == 0 { + return false + } + return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network) } func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 30afc013b..2ddc24bf2 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,10 +13,6 @@ import ( "github.com/netbirdio/netbird/route" ) -const ( - exitNodeCIDR = "0.0.0.0/0" -) - type RouteSelector struct { mu sync.RWMutex deselectedRoutes map[route.NetID]struct{} @@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.deselectAll { - return false - } - - _, deselected := rs.deselectedRoutes[routeID] - isSelected := !deselected - return isSelected + return rs.isSelectedLocked(routeID) } // FilterSelected removes unselected routes from the provided map. @@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { filtered := route.HAMap{} for id, rt := range routes { - netID := id.NetID() - _, deselected := rs.deselectedRoutes[netID] - if !deselected { + if !rs.isDeselectedLocked(id.NetID()) { filtered[id] = rt } } return filtered } -// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route +// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route. +// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of +// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing +// transparently apply to the synthesized v6 entry. func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - _, selected := rs.selectedRoutes[routeID] - _, deselected := rs.deselectedRoutes[routeID] - return selected || deselected + return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID)) } func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap { @@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap filtered := make(route.HAMap, len(routes)) for id, rt := range routes { netID := id.NetID() - if rs.isDeselected(netID) { + if rs.isDeselectedLocked(netID) { continue } @@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap return filtered } -func (rs *RouteSelector) isDeselected(netID route.NetID) bool { +// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit +// state of its own, so selections made on the v4 entry govern the v6 entry automatically. +// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route +// would make it inherit unrelated v4 state. Must be called with rs.mu held. +func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID { + name := string(id) + if !strings.HasSuffix(name, route.V6ExitSuffix) { + return id + } + if _, ok := rs.selectedRoutes[id]; ok { + return id + } + if _, ok := rs.deselectedRoutes[id]; ok { + return id + } + return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix)) +} + +func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool { + if rs.deselectAll { + return false + } + _, deselected := rs.deselectedRoutes[routeID] + return !deselected +} + +func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool { + if rs.deselectAll { + return true + } _, deselected := rs.deselectedRoutes[netID] - return deselected || rs.deselectAll + return deselected +} + +func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool { + _, selected := rs.selectedRoutes[routeID] + _, deselected := rs.deselectedRoutes[routeID] + return selected || deselected } func isExitNode(rt []*route.Route) bool { - return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR + return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network)) } func (rs *RouteSelector) applyExitNodeFilter( @@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter( rt []*route.Route, out route.HAMap, ) { - - if rs.hasUserSelections() { - // user made explicit selects/deselects - if rs.IsSelected(netID) { + // Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also + // drops the synthesized v6 entry that lacks its own explicit state. + effective := rs.effectiveNetID(netID) + if rs.hasUserSelectionForRouteLocked(effective) { + if rs.isSelectedLocked(effective) { out[id] = rt } return } - // no explicit selections: only include routes marked !SkipAutoApply (=AutoApply) + // no explicit selection for this route: defer to management's SkipAutoApply flag sel := collectSelected(rt) if len(sel) > 0 { out[id] = sel } } -func (rs *RouteSelector) hasUserSelections() bool { - return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0 -} - func collectSelected(rt []*route.Route) []*route.Route { var sel []*route.Route for _, r := range rt { diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 5faea2456..3f0d9f120 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) { assert.Len(t, filtered, 0) // No routes should be selected } +// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection +// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute +// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its +// v4 base, so legacy persisted selections that predate v6 pairing transparently +// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected) +// stay literal so unrelated routes named "*-v6" don't inherit unrelated state. +func TestRouteSelector_V6ExitPairInherits(t *testing.T) { + all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"} + + t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection") + + // unrelated v6 with no v4 base touched is unaffected + assert.False(t, rs.HasUserSelectionForRoute("exit2-v6")) + }) + + t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-exit route literally named "corp-v6" must not inherit "corp"'s state + // via the mirror; the mirror only applies in exit-node code paths. + assert.False(t, rs.IsSelected("corp")) + assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state") + }) + + t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0")) + assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base") + }) + + t.Run("non-v6-suffix routes unaffected", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + // A route literally named "exit1-something" must not pair-resolve. + assert.False(t, rs.HasUserSelectionForRoute("exit1-something")) + }) + + t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all)) + + v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")} + v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")} + routes := route.HAMap{ + "exit1|0.0.0.0/0": {v4Route}, + "exit1-v6|::/0": {v6Route}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair") + }) + + t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) { + rs := routeselector.NewRouteSelector() + require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all)) + + // A non-default-route entry named "corp-v6" is not an exit node and + // must not be skipped because its v4 base "corp" is deselected. + corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")} + routes := route.HAMap{ + "corp-v6|10.0.0.0/8": {corpV6}, + } + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"), + "non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes") + }) +} + +// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's +// SkipAutoApply flag governs each untouched route independently, even when +// the user has explicit selections on other routes. +func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) { + autoApplied := &route.Route{ + NetID: "Auto", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: false, + } + skipApply := &route.Route{ + NetID: "Skip", + Network: netip.MustParsePrefix("0.0.0.0/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "Auto|0.0.0.0/0": {autoApplied}, + "Skip|0.0.0.0/0": {skipApply}, + } + + rs := routeselector.NewRouteSelector() + // User makes an unrelated explicit selection elsewhere. + require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"})) + + filtered := rs.FilterSelectedExitNodes(routes) + assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included") + assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection") +} + +// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized +// as exit nodes by the selector's filter path. +func TestRouteSelector_V6ExitIsExitNode(t *testing.T) { + v6Exit := &route.Route{ + NetID: "V6Only", + Network: netip.MustParsePrefix("::/0"), + SkipAutoApply: true, + } + routes := route.HAMap{ + "V6Only|::/0": {v6Exit}, + } + + rs := routeselector.NewRouteSelector() + filtered := rs.FilterSelectedExitNodes(routes) + assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply") +} + func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialRoutes := []route.NetID{"route1", "route2", "route3"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} diff --git a/client/ui/network.go b/client/ui/network.go index 1619f78a2..cd5d23558 100644 --- a/client/ui/network.go +++ b/client/ui/network.go @@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { } func isDefaultRoute(routeRange string) bool { - return routeRange == "0.0.0.0/0" || routeRange == "::/0" + // routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0", + // "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry. + for _, part := range strings.Split(routeRange, ",") { + switch strings.TrimSpace(part) { + case "0.0.0.0/0", "::/0": + return true + } + } + return false } func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { From 9ed2e2a5b463077f8abe3e3926695f5dc9411e29 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 00:07:38 +0900 Subject: [PATCH 07/31] [client] Drop DNS probes for passive health projection (#5971) --- client/internal/connect.go | 2 - client/internal/dns/host.go | 12 + client/internal/dns/host_android.go | 19 +- client/internal/dns/host_ios.go | 9 + client/internal/dns/host_windows.go | 121 ++- client/internal/dns/hosts_dns_holder.go | 1 + client/internal/dns/local/local.go | 2 - client/internal/dns/mock_server.go | 9 +- client/internal/dns/network_manager_unix.go | 211 ++++- client/internal/dns/server.go | 928 ++++++++++++-------- client/internal/dns/server_android.go | 2 +- client/internal/dns/server_test.go | 698 +++++++++++++-- client/internal/dns/systemd_linux.go | 151 +++- client/internal/dns/upstream.go | 683 +++++++------- client/internal/dns/upstream_android.go | 5 +- client/internal/dns/upstream_general.go | 5 +- client/internal/dns/upstream_ios.go | 17 +- client/internal/dns/upstream_test.go | 227 +++-- client/internal/engine.go | 16 +- client/internal/routemanager/manager.go | 34 + client/internal/routemanager/mock.go | 9 + client/ios/NetBirdSDK/client.go | 6 +- 22 files changed, 2294 insertions(+), 873 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 8c0e9b1ba..ea884818f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, - dnsAddresses []netip.AddrPort, stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. @@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, - HostDNSAddresses: dnsAddresses, StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, "") diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index f7dc46a6b..48eacef29 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -16,6 +16,10 @@ type hostManager interface { restoreHostDNS() error supportCustomPort() bool string() string + // getOriginalNameservers returns the OS-side resolvers used as PriorityFallback + // upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android, + // hardcoded Quad9 on iOS, nil for noop / mock. + getOriginalNameservers() []netip.Addr } type SystemDNSSettings struct { @@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool { func (n noopHostConfigurator) string() string { return "noop" } + +func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} + +func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr { + return nil +} diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index dfa3e5712..48b3e0301 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,14 +1,20 @@ package dns import ( + "net/netip" + "github.com/netbirdio/netbird/client/internal/statemanager" ) +// androidHostManager is a noop on the OS side (Android's VPN service handles +// DNS for us) but tracks the OS-reported resolver list pushed via +// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source. type androidHostManager struct { + holder *hostsDNSHolder } -func newHostManager() (*androidHostManager, error) { - return &androidHostManager{}, nil +func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) { + return &androidHostManager{holder: holder}, nil } func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { @@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) string() string { return "none" } + +func (a androidHostManager) getOriginalNameservers() []netip.Addr { + hosts := a.holder.get() + out := make([]netip.Addr, 0, len(hosts)) + for ap := range hosts { + out = append(out, ap.Addr()) + } + return out +} diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index 1c0ac63e9..860bb8b50 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,6 +3,7 @@ package dns import ( "encoding/json" "fmt" + "net/netip" log "github.com/sirupsen/logrus" @@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { }, nil } +func (a iosHostManager) getOriginalNameservers() []netip.Addr { + // Quad9 v4+v6: 9.9.9.9, 2620:fe::fe. + return []netip.Addr{ + netip.AddrFrom4([4]byte{9, 9, 9, 9}), + netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}), + } +} + func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 4a8cf8cec..4f6ece532 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -7,6 +7,7 @@ import ( "io" "net/netip" "os/exec" + "slices" "strings" "syscall" "time" @@ -44,9 +45,11 @@ const ( nrptMaxDomainsPerRule = 50 - interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` - interfaceConfigNameServerKey = "NameServer" - interfaceConfigSearchListKey = "SearchList" + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` + interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces` + interfaceConfigNameServerKey = "NameServer" + interfaceConfigDhcpNameSrvKey = "DhcpNameServer" + interfaceConfigSearchListKey = "SearchList" // Network interface DNS registration settings disableDynamicUpdateKey = "DisableDynamicUpdate" @@ -67,10 +70,11 @@ const ( ) type registryConfigurator struct { - guid string - routingAll bool - gpo bool - nrptEntryCount int + guid string + routingAll bool + gpo bool + nrptEntryCount int + origNameservers []netip.Addr } func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { @@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { gpo: useGPO, } + origNameservers, err := configurator.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from non-WG adapters: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers) + } + configurator.origNameservers = origNameservers + if err := configurator.configureInterface(); err != nil { log.Errorf("failed to configure interface settings: %v", err) } @@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { return configurator, nil } +// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface +// registry key except the WG adapter. v4 and v6 servers live in separate +// hives (Tcpip vs Tcpip6) keyed by the same interface GUID. +func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + var merr *multierror.Error + for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} { + addrs, err := r.captureFromTcpipRoot(root) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err)) + continue + } + for _, addr := range addrs { + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nberrors.FormatErrorOrNil(merr) +} + +func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) { + root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ) + if err != nil { + return nil, fmt.Errorf("open key: %w", err) + } + defer closer(root) + + guids, err := root.ReadSubKeyNames(-1) + if err != nil { + return nil, fmt.Errorf("read subkeys: %w", err) + } + + var out []netip.Addr + for _, guid := range guids { + if strings.EqualFold(guid, r.guid) { + continue + } + out = append(out, readInterfaceNameservers(rootPath, guid)...) + } + return out, nil +} + +func readInterfaceNameservers(rootPath, guid string) []netip.Addr { + keyPath := rootPath + "\\" + guid + k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) + if err != nil { + return nil + } + defer closer(k) + + // Static NameServer wins over DhcpNameServer for actual resolution. + for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} { + raw, _, err := k.GetStringValue(name) + if err != nil || raw == "" { + continue + } + if out := parseRegistryNameservers(raw); len(out) > 0 { + return out + } + } + return nil +} + +func parseRegistryNameservers(raw string) []netip.Addr { + var out []netip.Addr + for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) { + addr, err := netip.ParseAddr(strings.TrimSpace(field)) + if err != nil { + continue + } + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // Drop unzoned link-local: not routable without a scope id. If + // the user wrote "fe80::1%eth0" ParseAddr preserves the zone. + if addr.IsLinkLocalUnicast() && addr.Zone() == "" { + continue + } + out = append(out, addr) + } + return out +} + +func (r *registryConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(r.origNameservers) +} + func (r *registryConfigurator) supportCustomPort() bool { return false } diff --git a/client/internal/dns/hosts_dns_holder.go b/client/internal/dns/hosts_dns_holder.go index 980d917a7..9ecc397be 100644 --- a/client/internal/dns/hosts_dns_holder.go +++ b/client/internal/dns/hosts_dns_holder.go @@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) { h.mutex.Unlock() } +//nolint:unused func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} { h.mutex.RLock() l := h.unprotectedDNSList diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index e9d310f00..4a75a76b6 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability(context.Context) {} - // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { logger := log.WithFields(log.Fields{ diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 548b1f54f..31fedd9e5 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -9,6 +9,7 @@ import ( dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string { return make([]string, 0) } -// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface -func (m *MockServer) ProbeAvailability() { -} - func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { if m.UpdateServerConfigFunc != nil { return m.UpdateServerConfigFunc(domains) @@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } -// SetRouteChecker mock implementation of SetRouteChecker from Server interface -func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { +// SetRouteSources mock implementation of SetRouteSources from Server interface +func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) { // Mock implementation - no-op } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 66d82dcd7..3932e78b7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net/netip" + "slices" "strings" "time" @@ -32,6 +33,15 @@ const ( networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection" networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply" networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete" + networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config" + networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config" + networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface" + networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices" + networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config" + networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config" + networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData" + networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers" + networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers" networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0 networkManagerDbusIPv4Key = "ipv4" networkManagerDbusIPv6Key = "ipv6" @@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{ } type networkManagerDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - routingAll bool - ifaceName string + dbusLinkObject dbus.ObjectPath + routingAll bool + ifaceName string + origNameservers []netip.Addr } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface) - return &networkManagerDbusConfigurator{ + c := &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from NetworkManager: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads DNS servers from every NM device's +// IP4Config / IP6Config except our WG device. +func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + devices, err := networkManagerListDevices() + if err != nil { + return nil, fmt.Errorf("list devices: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, dev := range devices { + if dev == n.dbusLinkObject { + continue + } + ifaceName := readNetworkManagerDeviceInterface(dev) + for _, addr := range readNetworkManagerDeviceDNS(dev) { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + continue + } + // IP6Config.Nameservers is a byte slice without zone info; + // reattach the device's interface name so a captured fe80::โ€ฆ + // stays routable. + if addr.IsLinkLocalUnicast() && ifaceName != "" { + addr = addr.WithZone(ifaceName) + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return "" + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty) + if err != nil { + return "" + } + s, _ := v.Value().(string) + return s +} + +func networkManagerListDevices() ([]dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + return nil, fmt.Errorf("dbus NetworkManager: %w", err) + } + defer closeConn() + var devs []dbus.ObjectPath + if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil { + return nil, err + } + return devs, nil +} + +func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, devicePath) + if err != nil { + return nil + } + defer closeConn() + + var out []netip.Addr + if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" { + out = append(out, readIPv4ConfigDNS(path)...) + } + if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" { + out = append(out, readIPv6ConfigDNS(path)...) + } + return out +} + +func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath { + v, err := obj.GetProperty(property) + if err != nil { + return "" + } + path, ok := v.Value().(dbus.ObjectPath) + if !ok || path == "/" { + return "" + } + return path +} + +func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + + // NameserverData (NM 1.13+) carries strings; older NMs only expose the + // legacy uint32 Nameservers property. + if out := readIPv4NameserverData(obj); len(out) > 0 { + return out + } + return readIPv4LegacyNameservers(obj) +} + +func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([]map[string]dbus.Variant) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + addrVar, ok := entry["address"] + if !ok { + continue + } + s, ok := addrVar.Value().(string) + if !ok { + continue + } + if a, err := netip.ParseAddr(s); err == nil { + out = append(out, a) + } + } + return out +} + +func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr { + v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([]uint32) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, n := range raw { + var b [4]byte + binary.LittleEndian.PutUint32(b[:], n) + out = append(out, netip.AddrFrom4(b)) + } + return out +} + +func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(networkManagerDest, path) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty) + if err != nil { + return nil + } + raw, ok := v.Value().([][]byte) + if !ok { + return nil + } + out := make([]netip.Addr, 0, len(raw)) + for _, b := range raw { + if a, ok := netip.AddrFromSlice(b); ok { + out = append(out, a) + } + } + return out +} + +func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(n.origNameservers) } func (n *networkManagerDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6fe2e21b6..e689f3586 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,11 +6,10 @@ import ( "fmt" "net/netip" "net/url" - "os" - "runtime" - "strconv" + "slices" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -25,11 +24,31 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +const ( + // healthLookback must exceed the upstream query timeout so one + // query per refresh cycle is enough to keep a group marked healthy. + healthLookback = 60 * time.Second + nsGroupHealthRefreshInterval = 10 * time.Second + // defaultWarningDelayBase is the starting grace window before a + // "Nameserver group unreachable" event fires for a group that's + // never been healthy and only has overlay upstreams with no + // Connected peer. Per-server and overridable; see warningDelayFor. + defaultWarningDelayBase = 30 * time.Second + // warningDelayBonusCap caps the route-count bonus added to the + // base grace window. See warningDelayFor. + warningDelayBonusCap = 30 * time.Second +) + +// errNoUsableNameservers signals that a merged-domain group has no usable +// upstream servers. Callers should skip the group without treating it as a +// build failure. +var errNoUsableNameservers = errors.New("no usable nameservers") // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { @@ -54,10 +73,9 @@ type Server interface { UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string - ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error - SetRouteChecker(func(netip.Addr) bool) + SetRouteSources(selected, active func() route.HAMap) SetFirewall(Firewall) } @@ -66,12 +84,47 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } -// hostManagerWithOriginalNS extends the basic hostManager interface -type hostManagerWithOriginalNS interface { - hostManager - getOriginalNameservers() []netip.Addr +// nsGroupID identifies a nameserver group by the tuple (server list, domain +// list) so config updates produce stable IDs across recomputations. +type nsGroupID string + +// nsHealthSnapshot is the input to projectNSGroupHealth, captured under +// s.mux so projection runs lock-free. +type nsHealthSnapshot struct { + groups []*nbdns.NameServerGroup + merged map[netip.AddrPort]UpstreamHealth + selected route.HAMap + active route.HAMap } +// nsGroupProj holds per-group state for the emission rules. +type nsGroupProj struct { + // unhealthySince is the start of the current Unhealthy streak, + // zero when the group is not currently Unhealthy. + unhealthySince time.Time + // everHealthy is sticky: once the group has been Healthy at least + // once this session, subsequent failures skip warningDelay. + everHealthy bool + // warningActive tracks whether we've already published a warning + // for the current streak, so recovery emits iff a warning did. + warningActive bool +} + +// nsGroupVerdict is the outcome of evaluateNSGroupHealth. +type nsGroupVerdict int + +const ( + // nsVerdictUndecided means no upstream has a fresh observation + // (startup before first query, or records aged past healthLookback). + nsVerdictUndecided nsGroupVerdict = iota + // nsVerdictHealthy means at least one upstream's most-recent + // in-lookback observation is a success. + nsVerdictHealthy + // nsVerdictUnhealthy means at least one upstream has a recent + // failure and none has a fresher success. + nsVerdictUnhealthy +) + // DefaultServer dns server object type DefaultServer struct { ctx context.Context @@ -100,26 +153,46 @@ type DefaultServer struct { permanent bool hostsDNSHolder *hostsDNSHolder + // fallbackHandler is the upstream resolver currently registered at + // PriorityFallback. Tracked so registerFallback can Stop() the previous + // instance instead of leaking its context. + fallbackHandler handlerWithStop + // make sense on mobile only searchDomainNotifier *notifier iosDnsManager IosDnsManager statusRecorder *peer.Status stateManager *statemanager.Manager - routeMatch func(netip.Addr) bool + // selectedRoutes returns admin-enabled client routes. + selectedRoutes func() route.HAMap + // activeRoutes returns the subset whose peer is in StatusConnected. + activeRoutes func() route.HAMap - probeMu sync.Mutex - probeCancel context.CancelFunc - probeWg sync.WaitGroup + nsGroups []*nbdns.NameServerGroup + healthProjectMu sync.Mutex + // nsGroupProj is the per-group state used by the emission rules. + // Accessed only under healthProjectMu. + nsGroupProj map[nsGroupID]*nsGroupProj + // warningDelayBase is the base grace window for health projection. + // Set at construction, mutated only by tests. Read by the + // refresher goroutine so never change it while one is running. + warningDelayBase time.Duration + // healthRefresh is buffered=1; writers coalesce, senders never block. + // See refreshHealth for the lock-order rationale. + healthRefresh chan struct{} } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability(context.Context) ID() types.HandlerID } +type upstreamHealthReporter interface { + UpstreamHealth() map[netip.AddrPort]UpstreamHealth +} + type handlerWrapper struct { domain string handler handlerWithStop @@ -174,7 +247,6 @@ func NewDefaultServerPermanentUpstream( ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) ds.searchDomainNotifier.setListener(listener) @@ -182,21 +254,17 @@ func NewDefaultServerPermanentUpstream( return ds } -// NewDefaultServerIos returns a new dns server. It optimized for ios +// NewDefaultServerIos returns a new dns server. It optimized for ios. func NewDefaultServerIos( ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager, - hostsDnsList []netip.AddrPort, statusRecorder *peer.Status, disableSys bool, ) *DefaultServer { - log.Debugf("iOS host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager - ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.addHostRootZone() return ds } @@ -230,6 +298,8 @@ func newDefaultServer( hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied + warningDelayBase: defaultWarningDelayBase, + healthRefresh: make(chan struct{}, 1), } // register with root zone, handler chain takes care of the routing @@ -238,12 +308,26 @@ func newDefaultServer( return defaultServer } -// SetRouteChecker sets the function used by upstream resolvers to determine -// whether an IP is routed through the tunnel. -func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { +// SetRouteSources wires the route-manager accessors used by health +// projection to classify each upstream for emission timing. +func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) { s.mux.Lock() defer s.mux.Unlock() - s.routeMatch = f + s.selectedRoutes = selected + s.activeRoutes = active + + // Permanent / iOS constructors build the root handler before the + // engine wires route sources, so its selectedRoutes callback would + // otherwise remain nil and overlay upstreams would be classified + // as public. Propagate the new accessors to existing handlers. + type routeSettable interface { + setSelectedRoutes(func() route.HAMap) + } + for _, entry := range s.dnsMuxMap { + if h, ok := entry.handler.(routeSettable); ok { + h.setSelectedRoutes(selected) + } + } } // RegisterHandler registers a handler for the given domains with the given priority. @@ -256,7 +340,6 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain for _, domain := range domains { - // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } if !s.batchMode { @@ -357,6 +440,8 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) + s.startHealthRefresher() + // Keep using noop host manager if dns off requested or running in netstack mode. // Netstack mode currently doesn't have a way to receive DNS requests. // TODO: Use listener on localhost in netstack mode when running as root. @@ -370,6 +455,13 @@ func (s *DefaultServer) Initialize() (err error) { return fmt.Errorf("initialize: %w", err) } s.hostManager = hostManager + // On mobile-permanent setups the seeded host DNS list is the only + // source until the first network-map arrives; register it now so DNS + // works in that window. Desktop host managers register fallback when + // applyConfiguration runs. + if s.permanent { + s.registerFallback() + } return nil } @@ -394,13 +486,7 @@ func (s *DefaultServer) SetFirewall(fw Firewall) { // Stop stops the server func (s *DefaultServer) Stop() { - s.probeMu.Lock() - if s.probeCancel != nil { - s.probeCancel() - } s.ctxCancel() - s.probeMu.Unlock() - s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -411,6 +497,13 @@ func (s *DefaultServer) Stop() { } clear(s.extraDomains) + + // Clear health projection state so a subsequent Start doesn't + // inherit sticky flags (notably everHealthy) that would bypass + // the grace window during the next peer handshake. + s.healthProjectMu.Lock() + s.nsGroupProj = nil + s.healthProjectMu.Unlock() } func (s *DefaultServer) disableDNS() (retErr error) { @@ -424,10 +517,9 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } - // Deregister original nameservers if they were registered as fallback - if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { - log.Debugf("deregistering original nameservers as fallback handlers") - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + log.Debugf("deregistering fallback handlers") + s.clearFallback() } if err := s.hostManager.restoreHostDNS(); err != nil { @@ -441,27 +533,16 @@ func (s *DefaultServer) disableDNS() (retErr error) { return nil } -// OnUpdatedHostDNSServer update the DNS servers addresses for root zones -// It will be applied if the mgm server do not enforce DNS settings for root zone +// OnUpdatedHostDNSServer updates the fallback DNS upstreams. Called by Android +// outside the engine's sync mux when the OS reports a network change, so it +// takes s.mux to serialize against host manager swaps in Initialize/enableDNS. func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) { s.hostsDNSHolder.set(hostsDnsList) - - // Check if there's any root handler - var hasRootHandler bool - for _, handler := range s.dnsMuxMap { - if handler.domain == nbdns.RootZone { - hasRootHandler = true - break - } - } - - if hasRootHandler { - log.Debugf("on new host DNS config but skip to apply it") - return - } - log.Debugf("update host DNS settings: %+v", hostsDnsList) - s.addHostRootZone() + + s.mux.Lock() + defer s.mux.Unlock() + s.registerFallback() } // UpdateDNSServer processes an update received from the management service @@ -520,69 +601,6 @@ func (s *DefaultServer) SearchDomains() []string { return searchDomains } -// ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds. -// If a previous probe is still running, it will be cancelled before starting a new one. -func (s *DefaultServer) ProbeAvailability() { - if val := os.Getenv(envSkipDNSProbe); val != "" { - skipProbe, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err) - } - if skipProbe { - log.Infof("skipping DNS probe due to %s", envSkipDNSProbe) - return - } - } - - s.probeMu.Lock() - - // don't start probes on a stopped server - if s.ctx.Err() != nil { - s.probeMu.Unlock() - return - } - - // cancel any running probe - if s.probeCancel != nil { - s.probeCancel() - s.probeCancel = nil - } - - // wait for the previous probe goroutines to finish while holding - // the mutex so no other caller can start a new probe concurrently - s.probeWg.Wait() - - // start a new probe - probeCtx, probeCancel := context.WithCancel(s.ctx) - s.probeCancel = probeCancel - - s.probeWg.Add(1) - defer s.probeWg.Done() - - // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. - s.mux.Lock() - handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) - for _, mux := range s.dnsMuxMap { - handlers = append(handlers, mux.handler) - } - s.mux.Unlock() - - var wg sync.WaitGroup - for _, handler := range handlers { - wg.Add(1) - go func(h handlerWithStop) { - defer wg.Done() - h.ProbeAvailability(probeCtx) - }(handler) - } - - s.probeMu.Unlock() - - wg.Wait() - probeCancel() -} - func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { s.mux.Lock() defer s.mux.Unlock() @@ -746,19 +764,17 @@ func (s *DefaultServer) applyHostConfig() { s.currentConfigHash = hash } - s.registerFallback(config) + s.registerFallback() } // registerFallback registers original nameservers as low-priority fallback handlers. -func (s *DefaultServer) registerFallback(config HostDNSConfig) { - hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) - if !ok { - return - } - - originalNameservers := hostMgrWithNS.getOriginalNameservers() +// Replaces and Stop()s the previously-registered fallback handler so its +// context is released rather than leaked until GC. +func (s *DefaultServer) registerFallback() { + originalNameservers := s.hostManager.getOriginalNameservers() if len(originalNameservers) == 0 { - s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler") + s.clearFallback() return } @@ -775,21 +791,28 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } - handler.routeMatch = s.routeMatch + handler.selectedRoutes = s.selectedRoutes + var servers []netip.AddrPort for _, ns := range originalNameservers { - if ns == config.ServerIP { - log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) - continue - } - - addrPort := netip.AddrPortFrom(ns, DefaultPort) - handler.upstreamServers = append(handler.upstreamServers, addrPort) + servers = append(servers, netip.AddrPortFrom(ns, DefaultPort)) } - handler.deactivate = func(error) { /* always active */ } - handler.reactivate = func() { /* always active */ } + handler.addRace(servers) + prev := s.fallbackHandler + s.fallbackHandler = handler s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) + if prev != nil { + prev.Stop() + } +} + +func (s *DefaultServer) clearFallback() { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.fallbackHandler != nil { + s.fallbackHandler.Stop() + s.fallbackHandler = nil + } } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) { @@ -847,100 +870,99 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityUpstream + priority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { - basePriority = PriorityDefault + priority = PriorityDefault } - updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + update, err := s.buildMergedDomainHandler(domainGroup, priority) if err != nil { + if errors.Is(err, errNoUsableNameservers) { + log.Errorf("no usable nameservers for domain=%s", domainGroup.domain) + continue + } return nil, err } - muxUpdates = append(muxUpdates, updates...) + muxUpdates = append(muxUpdates, *update) } return muxUpdates, nil } -func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { - var muxUpdates []handlerWrapper +// buildMergedDomainHandler merges every nameserver group that targets the +// same domain into one handler whose inner groups are raced in parallel. +func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, priority int) (*handlerWrapper, error) { + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface, + s.statusRecorder, + s.hostsDNSHolder, + domain.Domain(domainGroup.domain), + ) + if err != nil { + return nil, fmt.Errorf("create upstream resolver: %v", err) + } + handler.selectedRoutes = s.selectedRoutes - for i, nsGroup := range domainGroup.groups { - // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts - priority := basePriority - i - - // Check if we're about to overlap with the next priority tier - if s.leaksPriority(domainGroup, basePriority, priority) { - break - } - - log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - domainGroup.domain, - ) - if err != nil { - return nil, fmt.Errorf("create upstream resolver: %v", err) - } - handler.routeMatch = s.routeMatch - - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - - if ns.IP == s.service.RuntimeIP() { - log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) - continue - } - - handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) - } - - if len(handler.upstreamServers) == 0 { - handler.Stop() - log.Errorf("received a nameserver group with an invalid nameserver list") + for _, nsGroup := range domainGroup.groups { + servers := s.filterNameServers(nsGroup.NameServers) + if len(servers) == 0 { + log.Warnf("nameserver group for domain=%s yielded no usable servers, skipping", domainGroup.domain) continue } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it tries to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - - muxUpdates = append(muxUpdates, handlerWrapper{ - domain: domainGroup.domain, - handler: handler, - priority: priority, - }) + handler.addRace(servers) } - return muxUpdates, nil + if len(handler.upstreamServers) == 0 { + handler.Stop() + return nil, errNoUsableNameservers + } + + log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority) + + return &handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }, nil } -func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) - return true - } - if basePriority == PriorityDefault && priority <= PriorityFallback { - log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityDefault-PriorityFallback) - return true +func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + if ns.IP == s.service.RuntimeIP() { + log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) + continue + } + out = append(out, ns.AddrPort()) } + return out +} - return false +// usableNameServers returns the subset of nameServers the handler would +// actually query. Matches filterNameServers without the warning logs, so +// it's safe to call on every health-projection tick. +func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var runtimeIP netip.Addr + if s.service != nil { + runtimeIP = s.service.RuntimeIP() + } + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + continue + } + if runtimeIP.IsValid() && ns.IP == runtimeIP { + continue + } + out = append(out, ns.AddrPort()) + } + return out } func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { @@ -951,175 +973,356 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { } muxUpdateMap := make(registeredHandlerMap) - var containsRootUpdate bool for _, update := range muxUpdates { - if update.domain == nbdns.RootZone { - containsRootUpdate = true - } s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.handler.ID()] = update } - // If there's no root update and we had a root handler, restore it - if !containsRootUpdate { - for _, existing := range s.dnsMuxMap { - if existing.domain == nbdns.RootZone { - s.addHostRootZone() - break - } - } - } - s.dnsMuxMap = muxUpdateMap } -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. -func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, - priority int, -) (deactivate func(error), reactivate func()) { - var removeIndex map[string]int - deactivate = func(err error) { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("Temporarily deactivating nameservers group due to timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, priority) - } - - for i, item := range s.currentConfig.Domains { - if _, found := removeIndex[item.Domain]; found { - s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, priority) - removeIndex[item.Domain] = i - } - } - - // Always apply host config when nameserver goes down, regardless of batch mode - s.applyHostConfig() - - go func() { - if err := s.stateManager.PersistState(s.ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } - }() - - if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { - s.addHostRootZone() - } - - s.updateNSState(nsGroup, err, false) - } - - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { - continue - } - s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, priority) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, priority) - } - - // Always apply host config when nameserver reactivates, regardless of batch mode - s.applyHostConfig() - - s.updateNSState(nsGroup, nil, true) - } - return -} - -func (s *DefaultServer) addHostRootZone() { - hostDNSServers := s.hostsDNSHolder.get() - if len(hostDNSServers) == 0 { - log.Debug("no host DNS servers available, skipping root zone handler creation") - return - } - - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - nbdns.RootZone, - ) - if err != nil { - log.Errorf("unable to create a new upstream resolver, error: %v", err) - return - } - handler.routeMatch = s.routeMatch - - handler.upstreamServers = maps.Keys(hostDNSServers) - handler.deactivate = func(error) {} - handler.reactivate = func() {} - - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) -} - +// updateNSGroupStates records the new group set and pokes the refresher. +// Must hold s.mux; projection runs async (see refreshHealth for why). func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { - var states []peer.NSGroupState + s.nsGroups = groups + select { + case s.healthRefresh <- struct{}{}: + default: + } +} - for _, group := range groups { - var servers []netip.AddrPort - for _, ns := range group.NameServers { - servers = append(servers, ns.AddrPort()) +// refreshHealth runs one projection cycle. Must not be called while +// holding s.mux: the route callbacks re-enter routemanager's lock. +func (s *DefaultServer) refreshHealth() { + s.mux.Lock() + groups := s.nsGroups + merged := s.collectUpstreamHealth() + selFn := s.selectedRoutes + actFn := s.activeRoutes + s.mux.Unlock() + + var selected, active route.HAMap + if selFn != nil { + selected = selFn() + } + if actFn != nil { + active = actFn() + } + + s.projectNSGroupHealth(nsHealthSnapshot{ + groups: groups, + merged: merged, + selected: selected, + active: active, + }) +} + +// projectNSGroupHealth applies the emission rules to the snapshot and +// publishes the resulting NSGroupStates. Serialized by healthProjectMu, +// lock-free wrt s.mux. +// +// Rules: +// - Healthy: emit recovery iff warningActive; set everHealthy. +// - Unhealthy: stamp unhealthySince on streak start; emit warning +// iff any of immediate / everHealthy / elapsed >= effective delay. +// - Undecided: no-op. +// +// "Immediate" means the group has at least one upstream that's public +// or overlay+Connected: no peer-startup race to wait out. +func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) { + if s.statusRecorder == nil { + return + } + + s.healthProjectMu.Lock() + defer s.healthProjectMu.Unlock() + + if s.nsGroupProj == nil { + s.nsGroupProj = make(map[nsGroupID]*nsGroupProj) + } + + now := time.Now() + delay := s.warningDelay(haMapRouteCount(snap.selected)) + states := make([]peer.NSGroupState, 0, len(snap.groups)) + seen := make(map[nsGroupID]struct{}, len(snap.groups)) + for _, group := range snap.groups { + servers := s.usableNameServers(group.NameServers) + if len(servers) == 0 { + continue + } + verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now) + id := generateGroupKey(group) + seen[id] = struct{}{} + + immediate := s.groupHasImmediateUpstream(servers, snap) + + p, known := s.nsGroupProj[id] + if !known { + p = &nsGroupProj{} + s.nsGroupProj[id] = p } - state := peer.NSGroupState{ - ID: generateGroupKey(group), + enabled := true + switch verdict { + case nsVerdictHealthy: + enabled = s.projectHealthy(p, servers) + case nsVerdictUnhealthy: + enabled = s.projectUnhealthy(p, servers, immediate, now, delay) + case nsVerdictUndecided: + // Stay Available until evidence says otherwise, unless a + // warning is already active for this group. Also clear any + // prior Unhealthy streak so a later Unhealthy verdict starts + // a fresh grace window rather than inheriting a stale one. + p.unhealthySince = time.Time{} + enabled = !p.warningActive + groupErr = nil + } + + states = append(states, peer.NSGroupState{ + ID: string(id), Servers: servers, Domains: group.Domains, - // The probe will determine the state, default enabled - Enabled: true, - Error: nil, - } - states = append(states, state) + Enabled: enabled, + Error: groupErr, + }) } - s.statusRecorder.UpdateDNSStates(states) -} - -func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) { - states := s.statusRecorder.GetDNSStates() - id := generateGroupKey(nsGroup) - for i, state := range states { - if state.ID == id { - states[i].Enabled = enabled - states[i].Error = err - break + for id := range s.nsGroupProj { + if _, ok := seen[id]; !ok { + delete(s.nsGroupProj, id) } } s.statusRecorder.UpdateDNSStates(states) } -func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { - var servers []string +// projectHealthy records a healthy tick on p and publishes a recovery +// event iff a warning was active for the current streak. Returns the +// Enabled flag to record in NSGroupState. +func (s *DefaultServer) projectHealthy(p *nsGroupProj, servers []netip.AddrPort) bool { + p.everHealthy = true + p.unhealthySince = time.Time{} + if !p.warningActive { + return true + } + log.Debugf("DNS health: group [%s] recovered, emitting event", joinAddrPorts(servers)) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_DNS, + "Nameserver group recovered", + "DNS servers are reachable again.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = false + return true +} + +// projectUnhealthy records an unhealthy tick on p, publishes the +// warning when the emission rules fire, and returns the Enabled flag +// to record in NSGroupState. +func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPort, immediate bool, now time.Time, delay time.Duration) bool { + streakStart := p.unhealthySince.IsZero() + if streakStart { + p.unhealthySince = now + } + reason := unhealthyEmitReason(immediate, p.everHealthy, now.Sub(p.unhealthySince), delay) + switch { + case reason != "" && !p.warningActive: + log.Debugf("DNS health: group [%s] unreachable, emitting event (reason=%s)", joinAddrPorts(servers), reason) + s.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "Nameserver group unreachable", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = true + case streakStart && reason == "": + // One line per streak, not per tick. + log.Debugf("DNS health: group [%s] unreachable but holding warning for up to %v (overlay-routed, no connected peer)", joinAddrPorts(servers), delay) + } + return false +} + +// warningDelay returns the grace window for the given selected-route +// count. Scales gently: +1s per 100 routes, capped by +// warningDelayBonusCap. Parallel handshakes mean handshake time grows +// much slower than route count, so linear scaling would overcorrect. +// +// TODO: revisit the scaling curve with real-world data โ€” the current +// values are a reasonable starting point, not a measured fit. +func (s *DefaultServer) warningDelay(routeCount int) time.Duration { + bonus := time.Duration(routeCount/100) * time.Second + if bonus > warningDelayBonusCap { + bonus = warningDelayBonusCap + } + return s.warningDelayBase + bonus +} + +// groupHasImmediateUpstream reports whether the group has at least one +// upstream in a classification that bypasses the grace window: public +// (outside the overlay range and not routed), or overlay/routed with a +// Connected peer. +// +// TODO(ipv6): include the v6 overlay prefix once it's plumbed in. +func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap nsHealthSnapshot) bool { + var overlayV4 netip.Prefix + if s.wgInterface != nil { + overlayV4 = s.wgInterface.Address().Network + } + for _, srv := range servers { + addr := srv.Addr().Unmap() + overlay := overlayV4.IsValid() && overlayV4.Contains(addr) + selMatched, selDynamic := haMapContains(snap.selected, addr) + // Treat an unknown (dynamic selected route) as possibly routed: + // the upstream might reach through a dynamic route whose Network + // hasn't resolved yet, and classifying as public would bypass + // the startup grace window. + routed := selMatched || selDynamic + if !overlay && !routed { + return true + } + if actMatched, _ := haMapContains(snap.active, addr); actMatched { + return true + } + } + return false +} + +// collectUpstreamHealth merges health snapshots across handlers, keeping +// the most recent success and failure per upstream when an address appears +// in more than one handler. +func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth { + merged := make(map[netip.AddrPort]UpstreamHealth) + for _, entry := range s.dnsMuxMap { + reporter, ok := entry.handler.(upstreamHealthReporter) + if !ok { + continue + } + for addr, h := range reporter.UpstreamHealth() { + existing, have := merged[addr] + if !have { + merged[addr] = h + continue + } + if h.LastOk.After(existing.LastOk) { + existing.LastOk = h.LastOk + } + if h.LastFail.After(existing.LastFail) { + existing.LastFail = h.LastFail + existing.LastErr = h.LastErr + } + merged[addr] = existing + } + } + return merged +} + +func (s *DefaultServer) startHealthRefresher() { + s.shutdownWg.Add(1) + go func() { + defer s.shutdownWg.Done() + ticker := time.NewTicker(nsGroupHealthRefreshInterval) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + case <-s.healthRefresh: + } + s.refreshHealth() + } + }() +} + +// evaluateNSGroupHealth decides a group's verdict from query records +// alone. Per upstream, the most-recent-in-lookback observation wins. +// Group is Healthy if any upstream is fresh-working, Unhealthy if any +// is fresh-broken with no fresh-working sibling, Undecided otherwise. +func evaluateNSGroupHealth(merged map[netip.AddrPort]UpstreamHealth, servers []netip.AddrPort, now time.Time) (nsGroupVerdict, error) { + anyWorking := false + anyBroken := false + var mostRecentFail time.Time + var mostRecentErr string + + for _, srv := range servers { + h, ok := merged[srv] + if !ok { + continue + } + switch classifyUpstreamHealth(h, now) { + case upstreamFresh: + anyWorking = true + case upstreamBroken: + anyBroken = true + if h.LastFail.After(mostRecentFail) { + mostRecentFail = h.LastFail + mostRecentErr = h.LastErr + } + } + } + + if anyWorking { + return nsVerdictHealthy, nil + } + if anyBroken { + if mostRecentErr == "" { + return nsVerdictUnhealthy, nil + } + return nsVerdictUnhealthy, errors.New(mostRecentErr) + } + return nsVerdictUndecided, nil +} + +// upstreamClassification is the per-upstream verdict within healthLookback. +type upstreamClassification int + +const ( + upstreamStale upstreamClassification = iota + upstreamFresh + upstreamBroken +) + +// classifyUpstreamHealth compares the last ok and last fail timestamps +// against healthLookback and returns which one (if any) counts. Fresh +// wins when both are in-window and ok is newer; broken otherwise. +func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassification { + okRecent := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= healthLookback + failRecent := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= healthLookback + switch { + case okRecent && failRecent: + if h.LastOk.After(h.LastFail) { + return upstreamFresh + } + return upstreamBroken + case okRecent: + return upstreamFresh + case failRecent: + return upstreamBroken + } + return upstreamStale +} + +func joinAddrPorts(servers []netip.AddrPort) string { + parts := make([]string, 0, len(servers)) + for _, s := range servers { + parts = append(parts, s.String()) + } + return strings.Join(parts, ", ") +} + +// generateGroupKey returns a stable identity for an NS group so health +// state (everHealthy / warningActive) survives reorderings in the +// configured nameserver or domain lists. +func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID { + servers := make([]string, 0, len(nsGroup.NameServers)) for _, ns := range nsGroup.NameServers { servers = append(servers, ns.AddrPort().String()) } - return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) + slices.Sort(servers) + domains := slices.Clone(nsGroup.Domains) + slices.Sort(domains) + return nsGroupID(fmt.Sprintf("%v_%v", servers, domains)) } // groupNSGroupsByDomain groups nameserver groups by their match domains @@ -1161,6 +1364,21 @@ func toZone(d domain.Domain) domain.Domain { ) } +// unhealthyEmitReason returns the tag of the rule that fires the +// warning now, or "" if the group is still inside its grace window. +func unhealthyEmitReason(immediate, everHealthy bool, elapsed, delay time.Duration) string { + switch { + case immediate: + return "immediate" + case everHealthy: + return "ever-healthy" + case elapsed >= delay: + return "grace-elapsed" + default: + return "" + } +} + // PopulateManagementDomain populates the DNS cache with management domain func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { if s.mgmtCacheResolver != nil { diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go index 7ca12d69d..b2cb26f65 100644 --- a/client/internal/dns/server_android.go +++ b/client/internal/dns/server_android.go @@ -1,5 +1,5 @@ package dns func (s *DefaultServer) initialize() (manager hostManager, err error) { - return newHostManager() + return newHostManager(s.hostsDNSHolder) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1026a29fc..722c2abd7 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -6,7 +6,7 @@ import ( "net" "net/netip" "os" - "strings" + "runtime" "testing" "time" @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -31,8 +32,10 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -101,16 +104,17 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } -func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { +func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase { var srvs []netip.AddrPort for _, srv := range servers { srvs = append(srvs, srv.AddrPort()) } - return &upstreamResolverBase{ - domain: domain, - upstreamServers: srvs, - cancel: func() {}, + u := &upstreamResolverBase{ + domain: domain.Domain(d), + cancel: func() {}, } + u.addRace(srvs) + return u } func TestUpdateDNSServer(t *testing.T) { @@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) { } } -func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { - hostManager := &mockHostConfigurator{} - server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: local.NewResolver(), - handlerChain: NewHandlerChain(), - hostManager: hostManager, - currentConfig: HostDNSConfig{ - Domains: []DomainConfig{ - {false, "domain0", false}, - {false, "domain1", false}, - {false, "domain2", false}, - }, - }, - statusRecorder: peer.NewRecorder("mgm"), - } - - var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { - domains := []string{} - for _, item := range config.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - domainsUpdate = strings.Join(domains, ",") - return nil - } - - deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ - Domains: []string{"domain1"}, - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, - }, - }, nil, 0) - - deactivate(nil) - expected := "domain0,domain2" - domains := []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got := strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, got) - } - - reactivate() - expected = "domain0,domain1,domain2" - domains = []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got = strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) - } -} - func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { } func TestDNSPermanent_updateUpstream(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } func TestDNSPermanent_matchOnly(t *testing.T) { + skipUnlessAndroid(t) wgIFace, err := createWgInterfaceWithBind(t) if err != nil { t.Fatal("failed to initialize wg interface") @@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } } +// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path, +// which only matches a real production setup on android (NewDefaultServerPermanentUpstream +// + androidHostManager). On non-android the desktop host manager replaces it +// during Initialize and the assertion stops making sense. Skipped here until we +// have an android CI runner. +func skipUnlessAndroid(t *testing.T) { + t.Helper() + if runtime.GOOS != "android" { + t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS") + } +} + func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { t.Helper() ov := os.Getenv("NB_WG_KERNEL_DISABLED") @@ -1065,7 +1017,6 @@ type mockHandler struct { func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} func (m *mockHandler) Stop() {} -func (m *mockHandler) ProbeAvailability(context.Context) {} func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) } type mockService struct{} @@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } +// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple +// admin-defined nameserver groups targeting the same domain collapse into a +// single handler with each group preserved as a sequential inner list. +func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) { + wgInterface := &mocWGIface{} + service := NewServiceViaMemory(wgInterface) + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgInterface, + service: service, + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: &noopHostConfigurator{}, + dnsMuxMap: make(registeredHandlerMap), + } + + groups := []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + } + + muxUpdates, err := server.buildUpstreamHandlerUpdate(groups) + require.NoError(t, err) + require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler") + assert.Equal(t, "example.com", muxUpdates[0].domain) + assert.Equal(t, PriorityUpstream, muxUpdates[0].priority) + + handler := muxUpdates[0].handler.(*upstreamResolver) + require.Len(t, handler.upstreamServers, 2, "handler should have two groups") + assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0]) + assert.Equal(t, upstreamRace{ + netip.MustParseAddrPort("192.0.2.2:53"), + netip.MustParseAddrPort("192.0.2.3:53"), + }, handler.upstreamServers[1]) +} + +// TestEvaluateNSGroupHealth covers the records-only verdict. The gate +// (overlay route selected-but-no-active-peer) is intentionally NOT an +// input to the evaluator anymore: the verdict drives the Enabled flag, +// which must always reflect what we actually observed. Gate-aware event +// suppression is tested separately in the projection test. +// +// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail, +// stale Ok, Ok newer than Fail, Fail newer than Ok}. +// Group verdict: any fresh-working โ†’ Healthy; any fresh-broken with no +// fresh-working โ†’ Unhealthy; otherwise Undecided. +func TestEvaluateNSGroupHealth(t *testing.T) { + now := time.Now() + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)} + recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"} + staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)} + staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"} + okThenFail := UpstreamHealth{ + LastOk: now.Add(-10 * time.Second), + LastFail: now.Add(-1 * time.Second), + LastErr: "timeout", + } + failThenOk := UpstreamHealth{ + LastOk: now.Add(-1 * time.Second), + LastFail: now.Add(-10 * time.Second), + LastErr: "timeout", + } + + tests := []struct { + name string + health map[netip.AddrPort]UpstreamHealth + servers []netip.AddrPort + wantVerdict nsGroupVerdict + wantErrSubst string + }{ + { + name: "no record, undecided", + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "fresh success, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "fresh failure, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "only stale success, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "only stale failure, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "both fresh, fail newer, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: okThenFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "both fresh, ok newer, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: failThenOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one success wins", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + b: recentOk, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one fail one unseen, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "two upstreams, all recent failures, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"}, + b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"}, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "SERVFAIL", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now) + assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch") + if tc.wantErrSubst != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSubst) + } else { + assert.NoError(t, err) + } + }) + } +} + +// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed +// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates +// without spinning up real handlers. +type healthStubHandler struct { + health map[netip.AddrPort]UpstreamHealth +} + +func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (h *healthStubHandler) Stop() {} +func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" } +func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + return h.health +} + +// TestProjection_SteadyStateIsSilent guards against duplicate events: +// while a group stays Unhealthy tick after tick, only the first +// Unhealthy transition may emit. Same for staying Healthy. +func TestProjection_SteadyStateIsSilent(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "first fail emits warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.tick() + fx.expectNoEvent("staying unhealthy must not re-emit") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery on transition") + + fx.tick() + fx.tick() + fx.expectNoEvent("staying healthy must not re-emit") +} + +// projTestFixture is the common setup for the projection tests: a +// single-upstream group whose route classification the test can flip by +// assigning to selected/active. Callers drive failures/successes by +// mutating stub.health and calling refreshHealth. +type projTestFixture struct { + t *testing.T + recorder *peer.Status + events <-chan *proto.SystemEvent + server *DefaultServer + stub *healthStubHandler + group *nbdns.NameServerGroup + srv netip.AddrPort + selected route.HAMap + active route.HAMap +} + +func newProjTestFixture(t *testing.T) *projTestFixture { + t.Helper() + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + srv := netip.MustParseAddrPort("100.64.0.1:53") + fx := &projTestFixture{ + t: t, + recorder: recorder, + events: sub.Events(), + stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}}, + srv: srv, + group: &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + }, + } + fx.server = &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return fx.selected }, + activeRoutes: func() route.HAMap { return fx.active }, + warningDelayBase: defaultWarningDelayBase, + } + fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream} + + fx.server.mux.Lock() + fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) + fx.server.mux.Unlock() + return fx +} + +func (f *projTestFixture) setHealth(h UpstreamHealth) { + f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h} +} + +func (f *projTestFixture) tick() []peer.NSGroupState { + f.server.refreshHealth() + return f.recorder.GetDNSStates() +} + +func (f *projTestFixture) expectNoEvent(why string) { + f.t.Helper() + select { + case evt := <-f.events: + f.t.Fatalf("unexpected event (%s): %+v", why, evt) + case <-time.After(100 * time.Millisecond): + } +} + +func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent { + f.t.Helper() + select { + case evt := <-f.events: + assert.Contains(f.t, evt.Message, substr, why) + return evt + case <-time.After(time.Second): + f.t.Fatalf("expected event (%s) with %q", why, substr) + return nil + } +} + +var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16") +var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}} + +// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream +// that is not inside any selected route (public DNS) fires the warning +// on the first Unhealthy tick, no grace period. +func TestProjection_PublicFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "public DNS failure") +} + +// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2: +// the upstream is inside a selected route AND the route has a Connected +// peer. Tunnel is up, failure is real, emit immediately. +func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "overlay + connected failure") +} + +// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the +// upstream is routed but no peer is Connected (Connecting/Idle/missing). +// First tick: Unhealthy display, no warning. After the grace window +// elapses with no recovery, the warning fires. +func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) { + grace := 50 * time.Millisecond + fx := newProjTestFixture(t) + fx.server.warningDelayBase = grace + fx.selected = overlayMapForTest + // active stays nil: routed but not connected. + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled, "display must reflect failure even during grace window") + fx.expectNoEvent("first fail tick within grace window") + + time.Sleep(grace + 10*time.Millisecond) + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "warning after grace window") +} + +// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream +// whose address is inside the WireGuard overlay range but is not +// covered by any selected route (peer-to-peer DNS without an explicit +// route). Until a peer reports Connected for that address, startup +// failures must be held just like the routed case. +func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + overlayPeer := netip.MustParseAddrPort("100.66.100.5:53") + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: 50 * time.Millisecond, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ + overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, + }} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-sub.Events(): + t.Fatalf("unexpected event during grace window: %+v", evt) + case <-time.After(100 * time.Millisecond): + } + + time.Sleep(60 * time.Millisecond) + stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}} + server.refreshHealth() + + select { + case evt := <-sub.Events(): + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected warning after grace window") + } +} + +// TestProjection_StopClearsHealthState verifies that Stop wipes the +// per-group projection state so a subsequent Start doesn't inherit +// sticky flags (notably everHealthy) that would bypass the grace +// window during the next peer handshake. +func TestProjection_StopClearsHealthState(t *testing.T) { + wgIface := &mocWGIface{} + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgIface, + service: NewServiceViaMemory(wgIface), + hostManager: &noopHostConfigurator{}, + extraDomains: map[domain.Domain]int{}, + dnsMuxMap: make(registeredHandlerMap), + statusRecorder: peer.NewRecorder("mgm"), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: defaultWarningDelayBase, + currentConfigHash: ^uint64(0), + } + server.ctx, server.ctxCancel = context.WithCancel(context.Background()) + + srv := netip.MustParseAddrPort("8.8.8.8:53") + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + server.healthProjectMu.Lock() + p, ok := server.nsGroupProj[generateGroupKey(group)] + server.healthProjectMu.Unlock() + require.True(t, ok, "projection state should exist after tick") + require.True(t, p.everHealthy, "tick with success must set everHealthy") + + server.Stop() + + server.healthProjectMu.Lock() + cleared := server.nsGroupProj == nil + server.healthProjectMu.Unlock() + assert.True(t, cleared, "Stop must clear nsGroupProj") +} + +// TestProjection_OverlayRecoversDuringGrace covers the happy path of +// rule 3: startup failures while the peer is handshaking, then the peer +// comes up and a query succeeds before the grace window elapses. No +// warning should ever have fired, and no recovery either. +func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { + fx := newProjTestFixture(t) + fx.server.warningDelayBase = 200 * time.Millisecond + fx.selected = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectNoEvent("fail within grace, warning suppressed") + + fx.active = overlayMapForTest + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("recovery without prior warning must not emit") +} + +// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the +// whole design leans on: recovery events only appear when a warning +// event was actually emitted for the current streak. A Healthy verdict +// without a prior warning is silent, so the user never sees "recovered" +// out of thin air. +func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("first healthy tick should not recover anything") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "public fail emits immediately") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery follows real warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "second cycle warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "second cycle recovery") +} + +// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group +// has ever been Healthy, subsequent failures skip the grace window even +// if classification says "routed + not connected". The system has +// proved it can work, so any new failure is real. +func TestProjection_EverHealthyOverridesDelay(t *testing.T) { + fx := newProjTestFixture(t) + // Large base so any emission must come from the everHealthy bypass, not elapsed time. + fx.server.warningDelayBase = time.Hour + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + // Establish "ever healthy". + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectNoEvent("first healthy tick") + + // Peer drops. Query fails. Routed + not connected โ†’ normally grace, + // but everHealthy flag bypasses it. + fx.active = nil + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "failure after ever-healthy must be immediate") +} + +// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff +// from the design discussion: once a group has been healthy, a brief +// reconnect that produces a failing tick will fire warning + recovery. +// This is by design: user-visible blips are accurate signal, not noise. +func TestProjection_ReconnectBlipEmitsPair(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "blip warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "blip recovery") +} + +// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream +// rule: a group with at least one public upstream is in the "immediate" +// category regardless of the other upstreams' routing, because the +// public one has no peer-startup excuse. Prevents public-DNS failures +// from being hidden behind a routed sibling. +func TestProjection_MixedGroupEmitsImmediately(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + events := sub.Events() + + public := netip.MustParseAddrPort("8.8.8.8:53") + overlay := netip.MustParseAddrPort("100.64.0.1:53") + overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}} + + server := &DefaultServer{ + ctx: context.Background(), + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return overlayMap }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: time.Hour, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{ + {IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())}, + {IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())}, + }, + } + stub := &healthStubHandler{ + health: map[netip.AddrPort]UpstreamHealth{ + public: {LastFail: time.Now(), LastErr: "servfail"}, + overlay: {LastFail: time.Now(), LastErr: "timeout"}, + }, + } + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-events: + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected immediate warning because group contains a public upstream") + } +} + func TestDNSLoopPrevention(t *testing.T) { wgInterface := &mocWGIface{} service := NewServiceViaMemory(wgInterface) @@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) { if tt.expectedHandlers > 0 { handler := muxUpdates[0].handler.(*upstreamResolver) - assert.Len(t, handler.upstreamServers, len(tt.expectedServers)) + flat := handler.flatUpstreams() + assert.Len(t, flat, len(tt.expectedServers)) if tt.shouldFilterOwnIP { - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { assert.NotEqual(t, dnsServerIP, upstream.Addr()) } } for _, expected := range tt.expectedServers { found := false - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { if upstream.Addr() == expected { found = true break diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 573dff540..bd301e177 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "slices" "time" "github.com/godbus/dbus/v5" @@ -40,10 +41,17 @@ const ( ) type systemdDbusConfigurator struct { - dbusLinkObject dbus.ObjectPath - ifaceName string + dbusLinkObject dbus.ObjectPath + ifaceName string + wgIndex int + origNameservers []netip.Addr } +const ( + systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS" + systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute" +) + // the types below are based on dbus specification, each field is mapped to a dbus type // see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types // see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types @@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) - return &systemdDbusConfigurator{ + c := &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), ifaceName: wgInterface, - }, nil + wgIndex: iface.Index, + } + + origNameservers, err := c.captureOriginalNameservers() + switch { + case err != nil: + log.Warnf("capture original nameservers from systemd-resolved: %v", err) + case len(origNameservers) == 0: + log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty") + default: + log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers) + } + c.origNameservers = origNameservers + return c, nil +} + +// captureOriginalNameservers reads per-link DNS from systemd-resolved for +// every default-route link except our own WG link. Non-default-route links +// (VPNs, docker bridges) are skipped because their upstreams wouldn't +// actually serve host queries. +func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("list interfaces: %w", err) + } + + seen := make(map[netip.Addr]struct{}) + var out []netip.Addr + for _, iface := range ifaces { + if !s.isCandidateLink(iface) { + continue + } + linkPath, err := getSystemdLinkPath(iface.Index) + if err != nil || !isSystemdLinkDefaultRoute(linkPath) { + continue + } + for _, addr := range readSystemdLinkDNS(linkPath) { + addr = normalizeSystemdAddr(addr, iface.Name) + if !addr.IsValid() { + continue + } + if _, dup := seen[addr]; dup { + continue + } + seen[addr] = struct{}{} + out = append(out, addr) + } + } + return out, nil +} + +func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool { + if iface.Index == s.wgIndex { + return false + } + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { + return false + } + return true +} + +// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches +// the link's iface name as zone for link-local v6 (Link.DNS strips it). +// Returns the zero Addr to signal "skip this entry". +func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr { + addr = addr.Unmap() + if !addr.IsValid() || addr.IsUnspecified() { + return netip.Addr{} + } + if addr.IsLinkLocalUnicast() { + return addr.WithZone(ifaceName) + } + return addr +} + +func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) { + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return "", fmt.Errorf("dbus resolve1: %w", err) + } + defer closeConn() + var p string + if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil { + return "", err + } + return dbus.ObjectPath(p), nil +} + +func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return false + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty) + if err != nil { + return false + } + b, ok := v.Value().(bool) + return ok && b +} + +func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr { + obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath) + if err != nil { + return nil + } + defer closeConn() + v, err := obj.GetProperty(systemdDbusLinkDNSProperty) + if err != nil { + return nil + } + entries, ok := v.Value().([][]any) + if !ok { + return nil + } + var out []netip.Addr + for _, entry := range entries { + if len(entry) < 2 { + continue + } + raw, ok := entry[1].([]byte) + if !ok { + continue + } + addr, ok := netip.AddrFromSlice(raw) + if !ok { + continue + } + out = append(out, addr) + } + return out +} + +func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr { + return slices.Clone(s.origNameservers) } func (s *systemdDbusConfigurator) supportCustomPort() bool { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 39064f26c..a4f713d68 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -1,3 +1,32 @@ +// Package dns implements the client-side DNS stack: listener/service on the +// peer's tunnel address, handler chain that routes questions by domain and +// priority, and upstream resolvers that forward what remains to configured +// nameservers. +// +// # Upstream resolution and the race model +// +// When two or more nameserver groups target the same domain, DefaultServer +// merges them into one upstream handler whose state is: +// +// upstreamResolverBase +// โ””โ”€โ”€ upstreamServers []upstreamRace // one entry per source NS group +// โ””โ”€โ”€ []netip.AddrPort // primary, fallback, ... +// +// Each source nameserver group contributes one upstreamRace. Within a race +// upstreams are tried in order: the next is used only on failure (timeout, +// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops +// the walk. When more than one race exists, ServeDNS fans out one +// goroutine per race and returns the first valid answer, cancelling the +// rest. A handler with a single race skips the fan-out. +// +// # Health projection +// +// Query outcomes are recorded per-upstream in UpstreamHealth. The server +// periodically merges these snapshots across handlers and projects them +// into peer.NSGroupState. There is no active probing: a group is marked +// unhealthy only when every seen upstream has a recent failure and none +// has a recent success. Healthyโ†’unhealthy fires a single +// SystemEvent_WARNING; steady-state refreshes do not duplicate it. package dns import ( @@ -11,11 +40,8 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" - "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" @@ -25,7 +51,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) var currentMTU uint16 = iface.DefaultMTU @@ -67,15 +94,17 @@ const ( // Set longer than UpstreamTimeout to ensure context timeout takes precedence ClientTimeout = 5 * time.Second - reactivatePeriod = 30 * time.Second - probeTimeout = 2 * time.Second - // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP // payload from the tunnel MTU. ipUDPHeaderSize = 60 + 8 -) -const testRecord = "com." + // raceMaxTotalTimeout caps the combined time spent walking all upstreams + // within one race, so a slow primary can't eat the whole race budget. + raceMaxTotalTimeout = 5 * time.Second + // raceMinPerUpstreamTimeout is the floor applied when dividing + // raceMaxTotalTimeout across upstreams within a race. + raceMinPerUpstreamTimeout = 2 * time.Second +) const ( protoUDP = "udp" @@ -84,6 +113,69 @@ const ( type dnsProtocolKey struct{} +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. +type upstreamProtocolResult struct { + protocol string +} + +type upstreamClient interface { + exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +type UpstreamResolver interface { + serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) + upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +// upstreamRace is an ordered list of upstreams derived from one configured +// nameserver group. Order matters: the first upstream is tried first, the +// second only on failure, and so on. Multiple upstreamRace values coexist +// inside one resolver when overlapping nameserver groups target the same +// domain; those races run in parallel and the first valid answer wins. +type upstreamRace []netip.AddrPort + +// UpstreamHealth is the last query-path outcome for a single upstream, +// consumed by nameserver-group status projection. +type UpstreamHealth struct { + LastOk time.Time + LastFail time.Time + LastErr string +} + +type upstreamResolverBase struct { + ctx context.Context + cancel context.CancelFunc + upstreamClient upstreamClient + upstreamServers []upstreamRace + domain domain.Domain + upstreamTimeout time.Duration + + healthMu sync.RWMutex + health map[netip.AddrPort]*UpstreamHealth + + statusRecorder *peer.Status + // selectedRoutes returns the current set of client routes the admin + // has enabled. Called lazily from the query hot path when an upstream + // might need a tunnel-bound client (iOS) and from health projection. + selectedRoutes func() route.HAMap +} + +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + +type raceResult struct { + msg *dns.Msg + upstream netip.AddrPort + protocol string + ede string + failures []upstreamFailure +} + // contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. func contextWithDNSProtocol(ctx context.Context, network string) context.Context { return context.WithValue(ctx, dnsProtocolKey{}, network) @@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string { return "" } -type upstreamProtocolKey struct{} - -// upstreamProtocolResult holds the protocol used for the upstream exchange. -// Stored as a pointer in context so the exchange function can set it. -type upstreamProtocolResult struct { - protocol string -} - -// contextWithupstreamProtocolResult stores a mutable result holder in the context. -func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { +// contextWithUpstreamProtocolResult stores a mutable result holder in the context. +func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { r := &upstreamProtocolResult{} return context.WithValue(ctx, upstreamProtocolKey{}, r), r } @@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) { } } -type upstreamClient interface { - exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type UpstreamResolver interface { - serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) - upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type upstreamResolverBase struct { - ctx context.Context - cancel context.CancelFunc - upstreamClient upstreamClient - upstreamServers []netip.AddrPort - domain string - disabled bool - successCount atomic.Int32 - mutex sync.Mutex - reactivatePeriod time.Duration - upstreamTimeout time.Duration - wg sync.WaitGroup - - deactivate func(error) - reactivate func() - statusRecorder *peer.Status - routeMatch func(netip.Addr) bool -} - -type upstreamFailure struct { - upstream netip.AddrPort - reason string -} - -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ - ctx: ctx, - cancel: cancel, - domain: domain, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - statusRecorder: statusRecorder, + ctx: ctx, + cancel: cancel, + domain: d, + upstreamTimeout: UpstreamTimeout, + statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("Upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.flatUpstreams()) } -// ID returns the unique handler ID +// ID returns the unique handler ID. Race groupings and within-race +// ordering are both part of the identity: [[A,B]] and [[A],[B]] query +// the same servers but with different semantics (serial fallback vs +// parallel race), so their handlers must not collide. func (u *upstreamResolverBase) ID() types.HandlerID { - servers := slices.Clone(u.upstreamServers) - slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) }) - hash := sha256.New() - hash.Write([]byte(u.domain + ":")) - for _, s := range servers { - hash.Write([]byte(s.String())) - hash.Write([]byte("|")) + hash.Write([]byte(u.domain.PunycodeString() + ":")) + for _, race := range u.upstreamServers { + hash.Write([]byte("[")) + for _, s := range race { + hash.Write([]byte(s.String())) + hash.Write([]byte("|")) + } + hash.Write([]byte("]")) } return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } @@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { } func (u *upstreamResolverBase) Stop() { - log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) + log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams()) u.cancel() +} - u.mutex.Lock() - u.wg.Wait() - u.mutex.Unlock() +// flatUpstreams is for logging and ID hashing only, not for dispatch. +func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort { + var out []netip.AddrPort + for _, g := range u.upstreamServers { + out = append(out, g...) + } + return out +} +// setSelectedRoutes swaps the accessor used to classify overlay-routed +// upstreams. Called when route sources are wired after the handler was +// built (permanent / iOS constructors). +func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) { + u.selectedRoutes = selected +} + +func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) { + if len(servers) == 0 { + return + } + u.upstreamServers = append(u.upstreamServers, slices.Clone(servers)) } // ServeDNS handles a DNS request @@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { - timeout := u.upstreamTimeout - if len(u.upstreamServers) > 1 { - maxTotal := 5 * time.Second - minPerUpstream := 2 * time.Second - scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) - if scaledTimeout > minPerUpstream { - timeout = scaledTimeout - } else { - timeout = minPerUpstream - } + groups := u.upstreamServers + switch len(groups) { + case 0: + return false, nil + case 1: + return u.tryOnlyRace(ctx, w, r, groups[0], logger) + default: + return u.raceAll(ctx, w, r, groups, logger) + } +} + +func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + res := u.tryRace(ctx, r, group) + if res.msg == nil { + return false, res.failures + } + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, res.failures +} + +// raceAll runs one worker per group in parallel, taking the first valid +// answer and cancelling the rest. +func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + raceCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Buffer sized to len(groups) so workers never block on send, even + // after the coordinator has returned. + results := make(chan raceResult, len(groups)) + for _, g := range groups { + // tryRace clones the request per attempt, so workers never share + // a *dns.Msg and concurrent EDNS0 mutations can't race. + go func(g upstreamRace) { + results <- u.tryRace(raceCtx, r, g) + }(g) } var failures []upstreamFailure - for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { - failures = append(failures, *failure) - } else { - return true, failures + for range groups { + select { + case res := <-results: + failures = append(failures, res.failures...) + if res.msg != nil { + if res.ede != "" { + resutil.SetMeta(w, "ede", res.ede) + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, failures + } + case <-ctx.Done(): + return false, failures } } return false, failures } -// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. -func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { - var rm *dns.Msg - var t time.Duration - var err error +func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult { + timeout := u.upstreamTimeout + if len(group) > 1 { + // Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts + // still honor raceMinPerUpstreamTimeout as a floor for correctness + // on slow links, but the outer context ensures the combined walk + // cannot exceed the cap regardless of group size. + timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout) + defer cancel() + } + + var failures []upstreamFailure + for _, upstream := range group { + if ctx.Err() != nil { + return raceResult{failures: failures} + } + // Clone the request per attempt: the exchange path mutates EDNS0 + // options in-place, so reusing the same *dns.Msg across sequential + // upstreams would carry those mutations (e.g. a reduced UDP size) + // into the next attempt. + res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout) + if failure != nil { + failures = append(failures, *failure) + continue + } + res.failures = failures + return res + } + return raceResult{failures: failures} +} + +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx) // Advertise EDNS0 so the upstream may include Extended DNS Errors // (RFC 8914) in failure responses; we use those to short-circuit // failover for definitive answers like DNSSEC validation failures. - // Operate on a copy so the inbound request is unchanged: a client that - // did not advertise EDNS0 must not see an OPT in the response. + // The caller already passed a per-attempt copy, so we can mutate r + // directly; hadEdns reflects the original client request's state and + // controls whether we strip the OPT from the response. hadEdns := r.IsEdns0() != nil - reqUp := r if !hadEdns { - reqUp = r.Copy() - reqUp.SetEdns0(upstreamUDPSize(), false) + r.SetEdns0(upstreamUDPSize(), false) } - var startTime time.Time - var upstreamProto *upstreamProtocolResult - func() { - ctx, cancel := context.WithTimeout(parentCtx, timeout) - defer cancel() - ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) - startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp) - }() + startTime := time.Now() + rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r) if err != nil { - return u.handleUpstreamError(err, upstream, startTime) + // A parent cancellation (e.g., another race won and the coordinator + // cancelled the losers) is not an upstream failure. Check both the + // error chain and the parent context: a transport may surface the + // cancellation as a read/deadline error rather than context.Canceled. + if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) { + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"} + } + failure := u.handleUpstreamError(err, upstream, startTime) + u.markUpstreamFail(upstream, failure.reason) + return raceResult{}, failure } if rm == nil || !rm.Response { - return &upstreamFailure{upstream: upstream, reason: "no response"} + u.markUpstreamFail(upstream, "no response") + return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} + } + + proto := "" + if upstreamProto != nil { + proto = upstreamProto.protocol } if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if code, ok := nonRetryableEDE(rm); ok { - resutil.SetMeta(w, "ede", edeName(code)) if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil } - return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + reason := dns.RcodeToString[rm.Rcode] + u.markUpstreamFail(upstream, reason) + return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} } if !hadEdns { stripOPT(rm) } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + + u.markUpstreamOk(upstream) + return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil +} + +// healthEntry returns the mutable health record for addr, lazily creating +// the map and the entry. Caller must hold u.healthMu. +func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth { + if u.health == nil { + u.health = make(map[netip.AddrPort]*UpstreamHealth) + } + h := u.health[addr] + if h == nil { + h = &UpstreamHealth{} + u.health[addr] = h + } + return h +} + +func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastOk = time.Now() + h.LastFail = time.Time{} + h.LastErr = "" +} + +func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastFail = time.Now() + h.LastErr = reason +} + +// UpstreamHealth returns a snapshot of per-upstream query outcomes. +func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + u.healthMu.RLock() + defer u.healthMu.RUnlock() + out := make(map[netip.AddrPort]UpstreamHealth, len(u.health)) + for k, v := range u.health { + out[k] = *v + } + return out } // upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, @@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { - u.successCount.Add(1) +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) { resutil.SetMeta(w, "upstream", upstream.String()) - if upstreamProto != nil && upstreamProto.protocol != "" { - resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + if proto != "" { + resutil.SetMeta(w, "upstream_protocol", proto) } // Clear Zero bit from external responses to prevent upstream servers from @@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) - return true } - - return true } func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { - totalUpstreams := len(u.upstreamServers) + totalUpstreams := len(u.flatUpstreams()) failedCount := len(failures) failureSummary := formatFailures(failures) @@ -434,119 +633,6 @@ func edeName(code uint16) string { return fmt.Sprintf("EDE %d", code) } -// ProbeAvailability tests all upstream servers simultaneously and -// disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { - u.mutex.Lock() - defer u.mutex.Unlock() - - // avoid probe if upstreams could resolve at least one query - if u.successCount.Load() > 0 { - return - } - - var success bool - var mu sync.Mutex - var wg sync.WaitGroup - - var errs *multierror.Error - for _, upstream := range u.upstreamServers { - wg.Add(1) - go func(upstream netip.AddrPort) { - defer wg.Done() - err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) - if err != nil { - mu.Lock() - errs = multierror.Append(errs, err) - mu.Unlock() - log.Warnf("probing upstream nameserver %s: %s", upstream, err) - return - } - - mu.Lock() - success = true - mu.Unlock() - }(upstream) - } - - wg.Wait() - - select { - case <-ctx.Done(): - return - case <-u.ctx.Done(): - return - default: - } - - // didn't find a working upstream server, let's disable and try later - if !success { - u.disable(errs.ErrorOrNil()) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (probe failed)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - ) - } -} - -// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response -func (u *upstreamResolverBase) waitUntilResponse() { - exponentialBackOff := &backoff.ExponentialBackOff{ - InitialInterval: 500 * time.Millisecond, - RandomizationFactor: 0.5, - Multiplier: 1.1, - MaxInterval: u.reactivatePeriod, - MaxElapsedTime: 0, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - } - - operation := func() error { - select { - case <-u.ctx.Done(): - return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString())) - default: - } - - for _, upstream := range u.upstreamServers { - if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { - log.Tracef("upstream check for %s: %s", upstream, err) - } else { - // at least one upstream server is available, stop probing - return nil - } - } - - log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff()) - return fmt.Errorf("upstream check call error") - } - - err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) - if err != nil { - if errors.Is(err, context.Canceled) { - log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) - } else { - log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) - } - return - } - - log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) - u.successCount.Add(1) - u.reactivate() - u.mutex.Lock() - u.disabled = false - u.mutex.Unlock() -} - // isTimeout returns true if the given error is a network timeout error. // // Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout @@ -558,45 +644,6 @@ func isTimeout(err error) bool { return false } -func (u *upstreamResolverBase) disable(err error) { - if u.disabled { - return - } - - log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) - u.successCount.Store(0) - u.deactivate(err) - u.disabled = true - u.wg.Add(1) - go func() { - defer u.wg.Done() - u.waitUntilResponse() - }() -} - -func (u *upstreamResolverBase) upstreamServersString() string { - var servers []string - for _, server := range u.upstreamServers { - servers = append(servers, server.String()) - } - return strings.Join(servers, ", ") -} - -func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { - mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) - defer cancel() - - if externalCtx != nil { - stop2 := context.AfterFunc(externalCtx, cancel) - defer stop2() - } - - r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - - _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) - return err -} - // clientUDPMaxSize returns the maximum UDP response size the client accepts. func clientUDPMaxSize(r *dns.Msg) int { if opt := r.IsEdns0(); opt != nil { @@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int { // ExchangeWithFallback exchanges a DNS message with the upstream server. // It first tries to use UDP, and if it is truncated, it falls back to TCP. // If the inbound request came over TCP (via context), it skips the UDP attempt. -// If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { // If the request came in over TCP, go straight to TCP upstream. if dnsProtocolFromContext(ctx) == protoTCP { - tcpClient := *client - tcpClient.Net = protoTCP - rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream) + rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u opt.SetUDPSize(maxUDPPayload) } - var ( - rm *dns.Msg - t time.Duration - err error - ) - - if ctx == nil { - rm, t, err = client.Exchange(r, upstream) - } else { - rm, t, err = client.ExchangeContext(ctx, r, upstream) - } - + rm, t, err := client.ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with udp: %w", err) } @@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u // data than the client's buffer, we could truncate locally and skip // the TCP retry. - tcpClient := *client - tcpClient.Net = protoTCP - - if ctx == nil { - rm, t, err = tcpClient.Exchange(r, upstream) - } else { - rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream) - } - + rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream) if err != nil { return nil, t, fmt.Errorf("with tcp: %w", err) } @@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } +// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a +// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on +// the tunnel interface), it is converted to the equivalent *net.TCPAddr +// so net.Dialer doesn't reject the TCP dial with "mismatched local +// address type". +func toTCPClient(c *dns.Client) *dns.Client { + tcp := *c + tcp.Net = protoTCP + if tcp.Dialer == nil { + return &tcp + } + d := *tcp.Dialer + if ua, ok := d.LocalAddr.(*net.UDPAddr); ok { + d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone} + } + tcp.Dialer = &d + return &tcp +} + // ExchangeWithNetstack performs a DNS exchange using netstack for dialing. // This is needed when netstack is enabled to reach peer IPs through the tunnel. func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { @@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { return bestMatch } -func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { - if u.statusRecorder == nil { - return "" +// haMapRouteCount returns the total number of routes across all HA +// groups in the map. route.HAMap is keyed by HAUniqueID with slices of +// routes per key, so len(hm) is the number of HA groups, not routes. +func haMapRouteCount(hm route.HAMap) int { + total := 0 + for _, routes := range hm { + total += len(routes) } - - peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) - if peerInfo == nil { - return "" - } - - return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) + return total +} + +// haMapContains checks whether ip is covered by any concrete prefix in +// the HA map. haveDynamic is reported separately: dynamic (domain-based) +// routes carry a placeholder Network that can't be prefix-checked, so we +// can't know at this point whether ip is reached through one. Callers +// decide how to interpret the unknown: health projection treats it as +// "possibly routed" to avoid emitting false-positive warnings during +// startup, while iOS dial selection requires a concrete match before +// binding to the tunnel. +func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) { + for _, routes := range hm { + for _, r := range routes { + if r.IsDynamic() { + haveDynamic = true + continue + } + if r.Network.Contains(ip) { + return true, haveDynamic + } + } + } + return false, haveDynamic } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 988adb7d2..f7ab48b10 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -26,9 +27,9 @@ func newUpstreamResolver( _ WGIface, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 910c3779e..dc841757b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -12,6 +12,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -24,9 +25,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, nsNet: wgIface.GetNet(), diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 0e04742a0..b989bf0f9 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolverIOS struct { @@ -27,9 +28,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, @@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP = upstreamIP.Unmap() } addr := u.wgIface.Address() + var routed bool + if u.selectedRoutes != nil { + // Only a concrete prefix match binds to the tunnel: dialing + // through a private client for an upstream we can't prove is + // routed would break public resolvers. + routed, _ = haMapContains(u.selectedRoutes(), upstreamIP) + } needsPrivate := addr.Network.Contains(upstreamIP) || addr.IPv6Net.Contains(upstreamIP) || - (u.routeMatch != nil && u.routeMatch(upstreamIP)) + routed if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout) @@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } } - // Cannot use client.ExchangeContext because it overwrites our Dialer - return ExchangeWithFallback(nil, client, r, upstream) + return ExchangeWithFallback(ctx, client, r, upstream) } // GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface. diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index d6aec05ca..8b3c589f1 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync/atomic" "testing" "time" @@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())) } } - resolver.upstreamServers = servers + resolver.addRace(servers) resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() @@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) { return "", nil } -type mockUpstreamResolver struct { - r *dns.Msg - rtt time.Duration - err error -} - -// exchange mock implementation of exchange from upstreamResolver -func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - return c.r, c.rtt, c.err -} - type mockUpstreamResponse struct { - msg *dns.Msg - err error + msg *dns.Msg + err error + delay time.Duration } type mockUpstreamResolverPerServer struct { @@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct { rtt time.Duration } -func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { - if r, ok := c.responses[upstream]; ok { - return r.msg, c.rtt, r.err +func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { + r, ok := c.responses[upstream] + if !ok { + return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) } - return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) -} - -func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - mockClient := &mockUpstreamResolver{ - err: dns.ErrTime, - r: new(dns.Msg), - rtt: time.Millisecond, - } - - resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: mockClient, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: time.Microsecond * 100, - } - addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection - resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - - failed := false - resolver.deactivate = func(error) { - failed = true - // After deactivation, make the mock client work again - mockClient.err = nil - } - - reactivated := false - resolver.reactivate = func() { - reactivated = true - } - - resolver.ProbeAvailability(context.TODO()) - - if !failed { - t.Errorf("expected that resolving was deactivated") - return - } - - if !resolver.disabled { - t.Errorf("resolver should be Disabled") - return - } - - time.Sleep(time.Millisecond * 200) - - if !reactivated { - t.Errorf("expected that resolving was reactivated") - return - } - - if resolver.disabled { - t.Errorf("should be enabled") + if r.delay > 0 { + select { + case <-time.After(r.delay): + case <-ctx.Done(): + return nil, c.rtt, ctx.Err() + } } + return r.msg, c.rtt, r.err } func TestUpstreamResolver_Failover(t *testing.T) { @@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: trackingClient, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream1, upstream2}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: mockClient, - upstreamServers: []netip.AddrPort{upstream}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") } +// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups +// configured for the same domain, with one broken group. The merge+race +// path should answer as fast as the working group and not pay the timeout +// of the broken one on every query. +func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) { + broken := netip.MustParseAddrPort("192.0.2.1:53") + working := netip.MustParseAddrPort("192.0.2.2:53") + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + // Force the broken upstream to only unblock via timeout / + // cancellation so the assertion below can't pass if races + // were run serially. + broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond}, + working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: 250 * time.Millisecond, + } + resolver.addRace([]netip.AddrPort{broken}) + resolver.addRace([]netip.AddrPort{working}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + start := time.Now() + resolver.ServeDNS(responseWriter, inputMSG) + elapsed := time.Since(start) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode) + require.NotEmpty(t, responseMSG.Answer) + assert.Contains(t, responseMSG.Answer[0].String(), successAnswer) + // Working group answers in a single RTT; the broken group's + // timeout (100ms) must not block the response. + assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout") +} + +// TestUpstreamResolver_AllGroupsFail checks that when every group fails the +// resolver returns SERVFAIL rather than leaking a partial response. +func TestUpstreamResolver_AllGroupsFail(t *testing.T) { + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{a}) + resolver.addRace([]netip.AddrPort{b}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + require.NotNil(t, responseMSG) + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode) +} + +// TestUpstreamResolver_HealthTracking verifies that query-path results are +// recorded into per-upstream health, which is what projects back to +// NSGroupState for status reporting. +func TestUpstreamResolver_HealthTracking(t *testing.T) { + ok := netip.MustParseAddrPort("192.0.2.10:53") + bad := netip.MustParseAddrPort("192.0.2.11:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")}, + bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{ok, bad}) + + responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }} + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + + health := resolver.UpstreamHealth() + require.Contains(t, health, ok) + assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set") + assert.Empty(t, health[ok].LastErr) + + // bad upstream was never tried because ok answered first; its health + // should remain unset. + assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") +} + func TestFormatFailures(t *testing.T) { testCases := []struct { name string @@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { // Verify that a client EDNS0 larger than our MTU-derived limit gets // capped in the outgoing request so the upstream doesn't send a // response larger than our read buffer. - var receivedUDPSize uint16 + var receivedUDPSize atomic.Uint32 udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { if opt := r.IsEdns0(); opt != nil { - receivedUDPSize = opt.UDPSize() + receivedUDPSize.Store(uint32(opt.UDPSize())) } m := new(dns.Msg) m.SetReply(r) @@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) { require.NotNil(t, rm) expectedMax := uint16(currentMTU - ipUDPHeaderSize) - assert.Equal(t, expectedMax, receivedUDPSize, + assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()), "upstream should see capped EDNS0, not the client's 4096") } @@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: tracking, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamServers: []upstreamRace{{upstream1, upstream2}}, upstreamTimeout: UpstreamTimeout, } diff --git a/client/internal/engine.go b/client/internal/engine.go index 66fe6056b..3bd0d4621 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { - for _, routes := range e.routeManager.GetSelectedClientRoutes() { - for _, r := range routes { - if r.Network.Contains(ip) { - return true - } - } - } - return false - }) + e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes) if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) @@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.networkSerial = serial - // Test received (upstream) servers for availability right away instead of upon usage. - // If no server of a server group responds this will disable the respective handler and retry later. - go e.dnsServer.ProbeAvailability() return nil } @@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) return dnsServer, nil default: diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 907f1f592..839ec14c0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -53,6 +53,7 @@ type Manager interface { GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap GetSelectedClientRoutes() route.HAMap + GetActiveClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) } +// GetActiveClientRoutes returns the subset of selected client routes +// that are currently reachable: the route's peer is Connected and is +// the one actively carrying the route (not just an HA sibling). +func (m *DefaultManager) GetActiveClientRoutes() route.HAMap { + m.mux.Lock() + selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) + recorder := m.statusRecorder + m.mux.Unlock() + + if recorder == nil { + return selected + } + + out := make(route.HAMap, len(selected)) + for id, routes := range selected { + for _, r := range routes { + st, err := recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if st.ConnStatus != peer.StatusConnected { + continue + } + if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute { + continue + } + out[id] = routes + break + } + } + return out +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 66b5e30dd..937314995 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -19,6 +19,7 @@ type MockManager struct { GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap GetSelectedClientRoutesFunc func() route.HAMap + GetActiveClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap { return nil } +// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface +func (m *MockManager) GetActiveClientRoutes() route.HAMap { + if m.GetActiveClientRoutesFunc != nil { + return m.GetActiveClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil { diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 33f5ab1b0..bafbb0031 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - hostDNS := []netip.AddrPort{ - netip.MustParseAddrPort("9.9.9.9:53"), - netip.MustParseAddrPort("149.112.112.112:53"), - } - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources From e916f12cca508dfea584e7b72cf99a135acebc2b Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 15 May 2026 19:13:44 +0200 Subject: [PATCH 08/31] [proxy] auth token generation on mapping (#6157) * [management / proxy] auth token generation on mapping * fix tests --- management/internals/shared/grpc/proxy.go | 15 +++--- .../shared/grpc/proxy_snapshot_test.go | 53 +++++++++++++++++++ .../shared/grpc/validate_session_test.go | 14 +++-- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 9e5027547..eada2d86a 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -394,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if end > len(mappings) { end = len(mappings) } + for _, m := range mappings[i:end] { + token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL()) + if err != nil { + return fmt.Errorf("generate auth token for service %s: %w", m.Id, err) + } + m.AuthToken = token + } if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ Mapping: mappings[i:end], InitialSyncComplete: end == len(mappings), @@ -425,18 +432,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * return nil, fmt.Errorf("get services from store: %w", err) } + oidcCfg := s.GetOIDCValidationConfig() var mappings []*proto.ProxyMapping for _, service := range services { if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address { continue } - token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL()) - if err != nil { - return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err) - } - - m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + m := service.ToProtoMapping(rpservice.Create, "", oidcCfg) if !proxyAcceptsMapping(conn, m) { continue } diff --git a/management/internals/shared/grpc/proxy_snapshot_test.go b/management/internals/shared/grpc/proxy_snapshot_test.go index e0c7425c5..68d2ecfd1 100644 --- a/management/internals/shared/grpc/proxy_snapshot_test.go +++ b/management/internals/shared/grpc/proxy_snapshot_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) { assert.Empty(t, stream.messages[0].Mapping) assert.True(t, stream.messages[0].InitialSyncComplete) } + +type hookingStream struct { + grpc.ServerStream + onSend func(*proto.GetMappingUpdateResponse) +} + +func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error { + if s.onSend != nil { + s.onSend(m) + } + return nil +} + +func (s *hookingStream) Context() context.Context { return context.Background() } +func (s *hookingStream) SetHeader(metadata.MD) error { return nil } +func (s *hookingStream) SendHeader(metadata.MD) error { return nil } +func (s *hookingStream) SetTrailer(metadata.MD) {} +func (s *hookingStream) SendMsg(any) error { return nil } +func (s *hookingStream) RecvMsg(any) error { return nil } + +func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 6 + const ttl = 100 * time.Millisecond + const sendDelay = 200 * time.Millisecond + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + s.tokenTTL = ttl + + var validateErrs []error + stream := &hookingStream{ + onSend: func(resp *proto.GetMappingUpdateResponse) { + for _, m := range resp.Mapping { + if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil { + validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err)) + } + } + time.Sleep(sendDelay) + }, + } + conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream} + + require.NoError(t, s.sendSnapshot(context.Background(), conn)) + require.Empty(t, validateErrs, + "tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness") +} diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 6cd95f988..7b7ffcfb2 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -326,17 +326,25 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, return nil, nil } +func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error { + return nil +} + type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error { return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error { return nil } From 22e2519d7113dffec718198e54474cc0a6d71c87 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 22:51:48 +0900 Subject: [PATCH 09/31] [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) --- management/server/account.go | 37 +++++++++++-- management/server/account_test.go | 90 +++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 77a46a069..e7b4acaac 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.NewPermissionDeniedError() } + // Canonicalize the incoming range so a caller-supplied prefix with host bits + // (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net. + newSettings.NetworkRange = newSettings.NetworkRange.Masked() + var oldSettings *types.Settings var updateAccountPeers bool var groupChangesAffectPeers bool var reloadReverseProxy bool + var effectiveOldNetworkRange netip.Prefix err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var groupsUpdated bool @@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } + // No lock: the transaction already holds Settings(Update), and network.Net is + // only mutated by reallocateAccountPeerIPs, which is reachable only through + // this same code path. A Share lock here would extend an unnecessary row lock + // and complicate ordering against updatePeerIPv6InTransaction. + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("get account network: %w", err) + } + effectiveOldNetworkRange = prefixFromIPNet(network.Net) + if oldSettings.Extra != nil && newSettings.Extra != nil && oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) @@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } - if oldSettings.NetworkRange != newSettings.NetworkRange { + if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err } @@ -396,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta) } - if oldSettings.NetworkRange != newSettings.NetworkRange { + if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange { eventMeta := map[string]any{ - "old_network_range": oldSettings.NetworkRange.String(), + "old_network_range": effectiveOldNetworkRange.String(), "new_network_range": newSettings.NetworkRange.String(), } am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) @@ -443,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool { return !slices.Equal(oldGroups, newGroups) } +// prefixFromIPNet returns the overlay prefix actually allocated on the account +// network, or an invalid prefix if none is set. Settings.NetworkRange is a +// user-facing override that is empty on legacy accounts, so the effective +// range must be read from network.Net to compare against an incoming update. +func prefixFromIPNet(ipNet net.IPNet) netip.Prefix { + if ipNet.IP == nil { + return netip.Prefix{} + } + addr, ok := netip.AddrFromSlice(ipNet.IP) + if !ok { + return netip.Prefix{} + } + ones, _ := ipNet.Mask.Size() + return netip.PrefixFrom(addr.Unmap(), ones) +} + func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { diff --git a/management/server/account_test.go b/management/server/account_test.go index 65b27df49..60720faa6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3970,6 +3970,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi } } +// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against +// peer IP reallocation when a settings update carries the network range that is already +// in use. Legacy accounts have Settings.NetworkRange unset in the DB while network.Net +// holds the actual allocated overlay; the dashboard backfills the GET response from +// network.Net and echoes the value back on PUT, so the diff must be against the +// effective range to avoid renumbering every peer on an unrelated settings change. +func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + ctx := context.Background() + + settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset") + + network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id) + require.NoError(t, err) + require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated") + addr, ok := netip.AddrFromSlice(network.Net.IP) + require.True(t, ok) + ones, _ := network.Net.Mask.Size() + effective := netip.PrefixFrom(addr.Unmap(), ones) + require.True(t, effective.IsValid()) + + before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP} + + // Round-trip the effective range as if the dashboard echoed back the GET-backfilled value. + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: effective, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + require.Len(t, peers, len(before)) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID) + } + + // Carrying the same range with host bits set must also be a no-op once canonicalized. + hostBitsForm := netip.PrefixFrom(peer1.IP, ones) + require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking") + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: hostBitsForm, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID) + } + + // Omitting NetworkRange (invalid prefix) must also be a no-op. + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID) + } + + // Sanity: an actually different range still triggers reallocation. + newRange := netip.MustParsePrefix("100.99.0.0/16") + _, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + NetworkRange: newRange, + Extra: &types.ExtraSettings{}, + }) + require.NoError(t, err) + + peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "") + require.NoError(t, err) + for _, p := range peers { + assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP) + assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID) + } +} + func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) { manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) ctx := context.Background() From 347c5bf317794729a044ce9f866f29e357d386d9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 16 May 2026 16:29:01 +0200 Subject: [PATCH 10/31] Avoid context cancellation in `cancelPeerRoutines` (#6175) When closing go routines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation. This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation. --- management/internals/shared/grpc/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 70024bac6..1d8234304 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even } func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { - unlock := s.acquirePeerLockByUID(ctx, peer.Key) + uncanceledCTX := context.WithoutCancel(ctx) + unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key) defer unlock() - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime) } func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { From 3f91f49277e1841bdfccda06ae7baa0430e6de2e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 16 May 2026 23:52:57 +0900 Subject: [PATCH 11/31] Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) --- client/installer.nsis | 23 ++++++++++++++++++----- client/netbird.wxs | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/client/installer.nsis b/client/installer.nsis index 63bff1c5b..3e057df10 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" -; Create autostart registry entry based on checkbox +; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view +; or HKCU by legacy installers. +DetailPrint "Cleaning legacy 32-bit / HKCU entries..." +DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" +SetRegView 32 +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +DeleteRegKey HKLM "${REG_APP_PATH}" +DeleteRegKey HKLM "${UI_REG_APP_PATH}" +DeleteRegKey HKLM "${UNINSTALL_PATH}" +SetRegView 64 + DetailPrint "Autostart enabled: $AutostartEnabled" ${If} $AutostartEnabled == "1" WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"' DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" ${Else} DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" - ; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. - DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DetailPrint "Autostart not enabled by user" ${EndIf} @@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' DetailPrint "Terminating Netbird UI process..." ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` -; Remove autostart registry entry +; Remove autostart entries from every view a previous installer may have used. DetailPrint "Removing autostart registry entry if exists..." DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" -; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" +SetRegView 32 +DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +DeleteRegKey HKLM "${REG_APP_PATH}" +DeleteRegKey HKLM "${UI_REG_APP_PATH}" +DeleteRegKey HKLM "${UNINSTALL_PATH}" +SetRegView 64 ; Handle data deletion based on checkbox DetailPrint "Checking if user requested data deletion..." diff --git a/client/netbird.wxs b/client/netbird.wxs index 6f18b63b5..96814ce52 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -64,6 +64,13 @@ + + + + + @@ -76,10 +83,28 @@ + + + + + + + + + + + From 705f87fc20d4410fd8e21986725b2063f72d864d Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Mon, 18 May 2026 12:57:59 +0200 Subject: [PATCH 12/31] [management] fix: device redirect uri wasn't registered (#6191) * fix: device redirect uri wasn't registered * fix lint --- management/server/idp/embedded.go | 27 ++++++++++++++++++++----- management/server/idp/embedded_test.go | 28 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index a1852a8bc..821e6ff55 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" + "path" "strings" "github.com/dexidp/dex/storage" @@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { return nil, fmt.Errorf("invalid IdP storage config: %w", err) } - // Build CLI redirect URIs including the device callback (both relative and absolute) + // Build CLI redirect URIs including the device callback. Dex uses the issuer-relative + // path (for example, /oauth2/device/callback) when completing the device flow, so + // include it explicitly in addition to the legacy bare path and absolute URL. cliRedirectURIs := c.CLIRedirectURIs cliRedirectURIs = append(cliRedirectURIs, "/device/callback") - cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback") + cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer)) + cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback") // Build dashboard redirect URIs including the OAuth callback for proxy authentication dashboardRedirectURIs := c.DashboardRedirectURIs @@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { // MGMT api and the dashboard, adding baseURL means less configuration for the instance admin dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL) + redirectURIs := make([]string, 0) + redirectURIs = append(redirectURIs, cliRedirectURIs...) + redirectURIs = append(redirectURIs, dashboardRedirectURIs...) + cfg := &dex.YAMLConfig{ Issuer: c.Issuer, Storage: dex.Storage{ @@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { ID: staticClientDashboard, Name: "NetBird Dashboard", Public: true, - RedirectURIs: dashboardRedirectURIs, + RedirectURIs: redirectURIs, PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs), }, { ID: staticClientCLI, Name: "NetBird CLI", Public: true, - RedirectURIs: cliRedirectURIs, + RedirectURIs: redirectURIs, }, }, StaticConnectors: c.StaticConnectors, @@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { return cfg, nil } +func issuerRelativeDeviceCallback(issuer string) string { + u, err := url.Parse(issuer) + if err != nil || u.Path == "" { + return "/device/callback" + } + return path.Join(u.Path, "/device/callback") +} + // Due to how the frontend generates the logout, sometimes it appends a trailing slash // and because Dex only allows exact matches, we need to make sure we always have both // versions of each provided uri @@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) { } } - return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key))) + return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key))) } func validSessionCookieEncryptionKeyLength(length int) bool { diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index 09dc67614..91cd27aee 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) { }) } +func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) { + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "https://example.com/oauth2", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(t.TempDir(), "dex.db"), + }, + }, + } + + yamlConfig, err := config.ToYAMLConfig() + require.NoError(t, err) + + var cliRedirectURIs []string + for _, client := range yamlConfig.StaticClients { + if client.ID == staticClientCLI { + cliRedirectURIs = client.RedirectURIs + break + } + } + require.NotEmpty(t, cliRedirectURIs) + assert.Contains(t, cliRedirectURIs, "/device/callback") + assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback") + assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback") +} + func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) { t.Setenv(sessionCookieEncryptionKeyEnv, "") From 13d32d274f74b700557f8f6a615f56be2ab9c6a5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 May 2026 20:25:12 +0200 Subject: [PATCH 13/31] [management] Fence peer status updates with a session token (#6193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [management] Fence peer status updates with a session token The connect/disconnect path used a best-effort LastSeen-after-streamStart comparison to decide whether a status update should land. Under contention โ€” a re-sync arriving while the previous stream's disconnect was still in flight, or two management replicas seeing the same peer at once โ€” the check was a read-then-decide-then-write window: any UPDATE in between caused the wrong row to be written. The Go-side time.Now() that fed the comparison also drifted under lock contention, since it was captured seconds before the write actually committed. Replace it with an integer-nanosecond fencing token stored alongside the status. Every gRPC sync stream uses its open time (UnixNano) as its token. Connects only land when the incoming token is strictly greater than the stored one; disconnects only land when the incoming token equals the stored one (i.e. we're the stream that owns the current session). Both are single optimistic-locked UPDATEs โ€” no read-then-write, no transaction wrapper. LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The caller never supplies it, so the value always reflects the real moment of the UPDATE rather than the moment the caller queued the work โ€” which was already off by minutes under heavy lock contention. Side effects (geo lookup, peer-login-expiration scheduling, network-map fan-out) are explicitly documented as running after the fence UPDATE commits, never inside it. Geo also skips the update when realIP equals the stored ConnectionIP, dropping a redundant SavePeerLocation call on same-IP reconnects. Tests cover the three semantic cases (matched disconnect lands, stale disconnect dropped, stale connect dropped) plus a 16-goroutine race test that asserts the highest token always wins. * [management] Add SessionStartedAt to peer status updates Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety. * Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields --- management/server/account.go | 29 ++-- management/server/account/manager.go | 3 +- management/server/account/manager_mock.go | 22 ++- management/server/account_test.go | 115 ++++++++++++--- management/server/mock_server/account_mock.go | 24 +++- management/server/peer.go | 131 +++++++++--------- management/server/peer/peer.go | 19 ++- management/server/store/sql_store.go | 84 ++++++++++- management/server/store/store.go | 15 ++ management/server/store/store_mock.go | 30 ++++ 10 files changed, 354 insertions(+), 118 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e7b4acaac..8e4e595f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1868,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's +// network map and then marks the peer connected with a session token +// derived from syncTime (the moment the gRPC stream opened). Any +// concurrent stream that started earlier loses the optimistic-lock race +// in MarkPeerConnected and bails without writing. func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) - if err != nil { + if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return peer, netMap, postureChecks, dnsfwdPort, nil } +// OnPeerDisconnected is invoked when a sync stream ends. It marks the +// peer disconnected only when the stored SessionStartedAt matches the +// nanosecond token derived from streamStartTime โ€” i.e. only when this +// is the stream that currently owns the peer's session. A mismatch +// means a newer stream has already replaced us, so the disconnect is +// dropped. func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) - if err != nil { - log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) - return nil - } - - if peer.Status.LastSeen.After(streamStartTime) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", - peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) - return nil - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) - if err != nil { + if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 71af0645c..ae3de8d79 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -61,7 +61,8 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error + MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 7ffc41d73..0486e63ec 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal } // MarkPeerConnected mocks base method. -func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime) + ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt) ret0, _ := ret[0].(error) return ret0 } // MarkPeerConnected indicates an expected call of MarkPeerConnected. -func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt) +} + +// MarkPeerDisconnected mocks base method. +func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected. +func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt) } // OnPeerDisconnected mocks base method. diff --git a/management/server/account_test.go b/management/server/account_test.go index 60720faa6..ba621030c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }, false) require.NoError(t, err, "unable to add peer") - t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("disconnect peer when session token matches", func(t *testing.T) { + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err, "unable to get peer") require.True(t, peer.Status.Connected, "peer should be connected") - - streamStartTime := time.Now().UTC() + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal the token we passed in") err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) require.NoError(t, err) @@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.False(t, peer.Status.Connected, "peer should be disconnected") + require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0") }) - t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) { + // Newer stream wins on connect (sets SessionStartedAt = now ns). + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + // Older stream tries to mark disconnect with its own (older) session token โ€” + // fencing kicks in and the write is dropped. + staleStreamStartTime := streamStartTime.Add(-1 * time.Hour) - err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime) require.NoError(t, err) peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, - "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + "peer should remain connected because the stored session is newer than the disconnect token") + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should still hold the winning stream's token") }) - t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) { node2SyncTime := time.Now().UTC() - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano()) require.NoError(t, err, "node 2 should connect peer") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal node2SyncTime token") node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano()) require.NoError(t, err, "stale connect should not return error") peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should still be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), - "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should NOT be overwritten by stale token from blocked goroutine") }) } +// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the +// fencing protocol under contention: many goroutines race to mark the +// same peer connected with distinct session tokens at the same time. +// The contract is that the highest token always wins and is what remains +// in the store, regardless of execution order. +func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to get account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + const workers = 16 + base := time.Now().UTC().UnixNano() + tokens := make([]int64, workers) + for i := range tokens { + // Spread tokens by 1ms so the comparison is unambiguous; the + // largest is index workers-1. + tokens[i] = base + int64(i)*int64(time.Millisecond) + } + expected := tokens[workers-1] + + var ready sync.WaitGroup + ready.Add(workers) + var start sync.WaitGroup + start.Add(1) + var done sync.WaitGroup + done.Add(workers) + + // require.* calls t.FailNow which is documented as unsafe from + // non-test goroutines (it calls runtime.Goexit on the wrong stack and + // races with the WaitGroup). Collect errors here and assert from the + // main goroutine after done.Wait(). + errs := make(chan error, workers) + + for i := 0; i < workers; i++ { + token := tokens[i] + go func() { + defer done.Done() + ready.Done() + start.Wait() + errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token) + }() + } + + ready.Wait() + start.Done() + done.Wait() + close(errs) + for err := range errs { + require.NoError(t, err, "MarkPeerConnected must not error under contention") + } + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected after the race") + require.Equal(t, expected, peer.Status.SessionStartedAt, + "the largest token must win regardless of execution order") +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 08091d4b7..aba408184 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) @@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { +func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + // Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing + // hook so tests that inject MarkPeerDisconnectedFunc actually observe + // disconnect events. Falls through to nil when no hook is set, which + // is the original behaviour. + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano()) + } return nil } @@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) + return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } +// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface +func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt) + } + return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented") +} + // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index c3b130ba2..4790a5aab 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,7 +16,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -63,56 +62,51 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } -// MarkPeerConnected marks peer as connected (true) or disconnected (false) -// syncTime is used as the LastSeen timestamp and for stale request detection -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { - var peer *nbpeer.Peer - var settings *types.Settings - var expired bool - var err error - var skipped bool - - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) - if err != nil { - return err - } - - if connected && !syncTime.After(peer.Status.LastSeen) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", - peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) - skipped = true - return nil - } - - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) - return err - }) - if skipped { - return nil - } +// MarkPeerConnected marks a peer as connected with optimistic-locked +// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument +// is the start time of the gRPC sync stream that owns this update, +// expressed as Unix nanoseconds โ€” only the call whose token is greater +// than what's stored wins. LastSeen is written by the database itself; +// we never pass it down. +// +// Disconnects use MarkPeerDisconnected and require the session to match +// exactly; see PeerStatus.SessionStartedAt for the protocol. +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { return err } + updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + return err + } + if !updated { + log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID) + return nil + } + + if am.geo != nil && realIP != nil { + am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP) + } + + expired := peer.Status != nil && peer.Status.LoginExpired + if peer.AddedWithSSOLogin() { - settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } - if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.schedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) - if err != nil { + if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } } @@ -120,41 +114,46 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { - oldStatus := peer.Status.Copy() - newStatus := oldStatus - newStatus.LastSeen = syncTime - newStatus.Connected = connected - // whenever peer got connected that means that it logged in successfully - if newStatus.Connected { - newStatus.LoginExpired = false - } - peer.Status = newStatus - - if geo != nil && realIP != nil { - location, err := geo.Lookup(realIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) - } else { - peer.Location.ConnectionIP = realIP - peer.Location.CountryCode = location.Country.ISOCode - peer.Location.CityName = location.City.Names.En - peer.Location.GeoNameID = location.City.GeonameID - err = transaction.SavePeerLocation(ctx, accountID, peer) - if err != nil { - log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) - } - } - } - - log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) - - err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) +// MarkPeerDisconnected marks a peer as disconnected, but only when the +// stored session token matches the one passed in. A mismatch means a +// newer stream has already taken ownership of the peer โ€” disconnects from +// the older stream are ignored. LastSeen is written by the database. +func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { - return false, err + return err } - return oldStatus.LoginExpired, nil + updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + return err + } + if !updated { + log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping", + peer.ID, sessionStartedAt) + } + return nil +} + +// updatePeerLocationIfChanged refreshes the geolocation on a separate +// row update, only when the connection IP actually changed. Geo lookups +// are expensive so we skip same-IP reconnects. +func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) { + if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) { + return + } + location, err := am.geo.Lookup(realIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) + return + } + peer.Location.ConnectionIP = realIP + peer.Location.CountryCode = location.Country.ISOCode + peer.Location.CityName = location.City.Names.En + peer.Location.GeoNameID = location.City.GeonameID + if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil { + log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) + } } // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 17df761a1..2963dfcbd 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -74,8 +74,19 @@ type ProxyMeta struct { } type PeerStatus struct { //nolint:revive - // LastSeen is the last time peer was connected to the management service + // LastSeen is the last time the peer status was updated (i.e. the last + // time we observed the peer being alive on a sync stream). Written by + // the database (CURRENT_TIMESTAMP) โ€” callers do not supply it. LastSeen time.Time + // SessionStartedAt records when the currently-active sync stream began, + // stored as Unix nanoseconds. It acts as the optimistic-locking token + // for status updates: a stream is only allowed to mutate the peer's + // status when its own token strictly exceeds the stored token (when connecting) + // or matches it exactly (for disconnects). Zero means "no + // active session". Integer nanoseconds are used so equality is + // precision-safe across drivers, and so the predicates compose to a + // single bigint comparison. + SessionStartedAt int64 // Connected indicates whether peer is connected to the management service or not Connected bool // LoginExpired @@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any { return meta } -// Copy PeerStatus +// Copy PeerStatus. SessionStartedAt must be propagated so clone-based +// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently +// reset the fencing token to zero โ€” that would let any subsequent +// SavePeerStatus write reopen the optimistic-lock window. func (p *PeerStatus) Copy() *PeerStatus { return &PeerStatus{ LastSeen: p.LastSeen, + SessionStartedAt: p.SessionStartedAt, Connected: p.Connected, LoginExpired: p.LoginExpired, RequiresApproval: p.RequiresApproval, diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 893ee2168..8cf37de56 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -498,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerCopy.Status = &peerStatus fieldsToUpdate := []string{ - "peer_status_last_seen", "peer_status_connected", - "peer_status_login_expired", "peer_status_required_approval", + "peer_status_last_seen", "peer_status_session_started_at", + "peer_status_connected", "peer_status_login_expired", + "peer_status_requires_approval", } result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). @@ -516,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, return nil } +// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update. +// The peer is marked connected with the given session token only when +// the stored SessionStartedAt is strictly smaller than the incoming +// one โ€” equivalently, when no newer stream has already taken ownership. +// The sentinel zero (set on peer creation or after a disconnect) counts +// as the smallest possible token. This is the write half of the +// fencing protocol described on PeerStatus.SessionStartedAt. +// +// The post-write side effects in the caller โ€” geo lookup, +// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration, +// OnPeersUpdated โ€” all run AFTER this method returns and are deliberately +// outside the database write so they cannot extend the row-lock window. +// +// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the +// moment the row is written. The caller never supplies LastSeen because +// the value would otherwise drift under lock contention โ€” a Go-side +// time.Now() taken before the write can land minutes later than the +// actual UPDATE under load, which previously caused real ordering bugs. +func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at < ?", newSessionStartedAt). + Updates(map[string]any{ + "peer_status_connected": true, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": newSessionStartedAt, + "peer_status_login_expired": false, + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + +// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update. +// The peer is marked disconnected only when the stored SessionStartedAt +// matches the incoming token โ€” meaning the stream that owns the current +// session is the one ending. If a newer stream has already replaced the +// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at +// write time; see MarkPeerConnectedIfNewerSession for the rationale. +// +// A zero sessionStartedAt is rejected at the call site; the underlying +// WHERE on equality would otherwise match every never-connected peer. +func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + if sessionStartedAt == 0 { + return false, nil + } + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at = ?", sessionStartedAt). + Updates(map[string]any{ + "peer_status_connected": false, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": int64(0), + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer @@ -1723,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, - meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired, - peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, - location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1` + meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at, + peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, + location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 + FROM peers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1738,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee lastLogin, createdAt sql.NullTime sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool peerStatusLastSeen sql.NullTime + peerStatusSessionStartedAt sql.NullInt64 peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString @@ -1752,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities, - &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, - &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6) + &peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired, + &peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID, + &proxyEmbedded, &proxyCluster, &ipv6) if err == nil { if lastLogin.Valid { @@ -1780,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if peerStatusLastSeen.Valid { p.Status.LastSeen = peerStatusLastSeen.Time } + if peerStatusSessionStartedAt.Valid { + p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64 + } if peerStatusConnected.Valid { p.Status.Connected = peerStatusConnected.Bool } diff --git a/management/server/store/store.go b/management/server/store/store.go index aa601c33f..a723c1fc3 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -167,6 +167,21 @@ type Store interface { GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error + // MarkPeerConnectedIfNewerSession sets the peer to connected with the + // given session token, but only when the stored SessionStartedAt is + // strictly less than newSessionStartedAt (the sentinel zero counts as + // "older"). LastSeen is recorded by the database at the moment the + // row is updated โ€” never by the caller โ€” so it always reflects the + // real write time even under lock contention. + // Returns true when the update happened, false when this stream lost + // the race against a newer session. + MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) + // MarkPeerDisconnectedIfSameSession sets the peer to disconnected and + // resets SessionStartedAt to zero, but only when the stored + // SessionStartedAt equals the given sessionStartedAt. LastSeen is + // recorded by the database. Returns true when the update happened, + // false when a newer session has taken over. + MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 9780c521e..d51629606 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -2878,6 +2878,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status) } +// MarkPeerConnectedIfNewerSession mocks base method. +func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession. +func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt) +} + +// MarkPeerDisconnectedIfSameSession mocks base method. +func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession. +func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt) +} + // SavePolicy mocks base method. func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error { m.ctrl.T.Helper() From af24fd779640538c05c5f261a1e9fdf20fe7773f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 May 2026 22:55:19 +0200 Subject: [PATCH 14/31] [management] Add metrics for peer status updates and ephemeral cleanup (#6196) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [management] Add metrics for peer status updates and ephemeral cleanup The session-fenced MarkPeerConnected / MarkPeerDisconnected path and the ephemeral peer cleanup loop both run silently today: when fencing rejects a stale stream, when a cleanup tick deletes peers, or when a batch delete fails, we have no operational signal beyond log lines. Add OpenTelemetry counters and a histogram so the same SLO-style dashboards that already exist for the network-map controller can cover peer connect/disconnect and ephemeral cleanup too. All new attributes are bounded enums: operation in {connect,disconnect} and outcome in {applied,stale,error,peer_not_found}. No account, peer, or user ID is ever written as a metric label โ€” total cardinality is fixed at compile time (8 counter series, 2 histogram series, 4 unlabeled ephemeral series). Metric methods are nil-receiver safe so test composition that doesn't wire telemetry (the bulk of the existing tests) works unchanged. The ephemeral manager exposes a SetMetrics setter rather than taking the collector through its constructor, keeping the constructor signature stable across all test call sites. * [management] Add OpenTelemetry metrics for ephemeral peer cleanup Introduce counters for tracking ephemeral peer cleanup, including peers pending deletion, cleanup runs, successful deletions, and failed batches. Metrics are nil-receiver safe to ensure compatibility with test setups without telemetry. --- .../peers/ephemeral/manager/ephemeral.go | 45 ++++++- management/internals/server/controllers.go | 6 +- management/server/peer.go | 28 +++++ .../telemetry/accountmanager_metrics.go | 65 ++++++++++ management/server/telemetry/app_metrics.go | 28 +++++ .../server/telemetry/ephemeral_metrics.go | 115 ++++++++++++++++++ 6 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 management/server/telemetry/ephemeral_metrics.go diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go index 758f643d0..0f902ea70 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/store" ) @@ -47,6 +48,11 @@ type EphemeralManager struct { lifeTime time.Duration cleanupWindow time.Duration + + // metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics + // no-op when the receiver is nil so deployments without an app + // metrics provider work unchanged. + metrics *telemetry.EphemeralPeersMetrics } // NewEphemeralManager instantiate new EphemeralManager @@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer } } +// SetMetrics attaches a metrics collector. Safe to call once before +// LoadInitialPeers; later attachment is fine but earlier loads won't be +// reflected in the gauge. Pass nil to detach. +func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) { + e.peersLock.Lock() + e.metrics = m + e.peersLock.Unlock() +} + // LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head // of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new // head. @@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee e.peersLock.Lock() defer e.peersLock.Unlock() - e.removePeer(peer.ID) + if e.removePeer(peer.ID) { + e.metrics.DecPending(1) + } // stop the unnecessary timer if e.headPeer == nil && e.timer != nil { @@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } e.addPeer(peer.AccountID, peer.ID, e.newDeadLine()) + e.metrics.IncPending() if e.timer == nil { delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow if delay < 0 { @@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { for _, p := range peers { e.addPeer(p.AccountID, p.ID, t) } + e.metrics.AddPending(int64(len(peers))) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } @@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() + // Drop the gauge by the number of entries we just took off the list, + // regardless of whether the subsequent DeletePeers call succeeds. The + // list invariant is what the gauge tracks; failed delete batches are + // counted separately via CountCleanupError so we can still see them. + if len(deletePeers) > 0 { + e.metrics.CountCleanupRun() + e.metrics.DecPending(int64(len(deletePeers))) + } + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) @@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err) + e.metrics.CountCleanupError() + continue } + e.metrics.CountPeersCleaned(int64(len(peerIDs))) } } @@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim e.tailPeer = ep } -func (e *EphemeralManager) removePeer(id string) { +// removePeer drops the entry from the linked list. Returns true if a +// matching entry was found and removed so callers can keep the pending +// metric gauge in sync. +func (e *EphemeralManager) removePeer(id string) bool { if e.headPeer == nil { - return + return false } if e.headPeer.id == id { @@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) { if e.tailPeer.id == id { e.tailPeer = nil } - return + return true } for p := e.headPeer; p.next != nil; p = p.next { @@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) { e.tailPeer = p } p.next = p.next.next - return + return true } } + return false } func (e *EphemeralManager) isPeerOnList(id string) bool { diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 89bdf0abe..794c3ebe0 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager { func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.PeersManager()) + em := manager.NewEphemeralManager(s.Store(), s.PeersManager()) + if metrics := s.Metrics(); metrics != nil { + em.SetMetrics(metrics.EphemeralPeersMetrics()) + } + return em }) } diff --git a/management/server/peer.go b/management/server/peer.go index 4790a5aab..34b681f51 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -28,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/shared/management/status" ) @@ -72,19 +73,32 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // Disconnects use MarkPeerDisconnected and require the session to match // exactly; see PeerStatus.SessionStartedAt for the protocol. func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { + start := time.Now() + defer func() { + am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start)) + }() + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { + outcome := telemetry.PeerStatusError + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + outcome = telemetry.PeerStatusPeerNotFound + } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome) return err } updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt) if err != nil { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError) return err } if !updated { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale) log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID) return nil } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied) if am.geo != nil && realIP != nil { am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP) @@ -119,19 +133,33 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK // newer stream has already taken ownership of the peer โ€” disconnects from // the older stream are ignored. LastSeen is written by the database. func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error { + start := time.Now() + defer func() { + am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start)) + }() + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { + outcome := telemetry.PeerStatusError + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + outcome = telemetry.PeerStatusPeerNotFound + } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome) return err } updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt) if err != nil { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError) return err } if !updated { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale) log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping", peer.ID, sessionStartedAt) + return nil } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied) return nil } diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 518aae7eb..bb6fb7e12 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -16,6 +16,8 @@ type AccountManagerMetrics struct { getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter + peerStatusUpdateCounter metric.Int64Counter + peerStatusUpdateDurationMs metric.Float64Histogram } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -64,6 +66,24 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + // peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected + peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)")) + if err != nil { + return nil, err + } + + peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithExplicitBucketBoundaries( + 1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000, + ), + metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, @@ -71,10 +91,35 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account updateAccountPeersCounter: updateAccountPeersCounter, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, + peerStatusUpdateCounter: peerStatusUpdateCounter, + peerStatusUpdateDurationMs: peerStatusUpdateDurationMs, }, nil } +// PeerStatusOperation labels the kind of fence-locked peer status write. +type PeerStatusOperation string + +// PeerStatusOutcome labels how a fence-locked peer status write resolved. +type PeerStatusOutcome string + +const ( + PeerStatusConnect PeerStatusOperation = "connect" + PeerStatusDisconnect PeerStatusOperation = "disconnect" + + // PeerStatusApplied โ€” the fence WHERE matched and the UPDATE landed. + PeerStatusApplied PeerStatusOutcome = "applied" + // PeerStatusStale โ€” the fence WHERE rejected the write because a + // newer session has already taken ownership (connect: stored token + // >= incoming; disconnect: stored token != incoming). + PeerStatusStale PeerStatusOutcome = "stale" + // PeerStatusError โ€” the store returned a non-NotFound error. + PeerStatusError PeerStatusOutcome = "error" + // PeerStatusPeerNotFound โ€” the peer lookup failed (the peer was + // deleted between the gRPC sync handshake and the status write). + PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found" +) + // CountUpdateAccountPeersDuration counts the duration of updating account peers func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) { metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6) @@ -104,3 +149,23 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) } + +// CountPeerStatusUpdate increments the connect/disconnect counter, +// labeled by operation and outcome. Both labels are bounded enums. +func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) { + metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1, + metric.WithAttributes( + attribute.String("operation", string(op)), + attribute.String("outcome", string(outcome)), + ), + ) +} + +// RecordPeerStatusUpdateDuration records the wall-clock time spent +// running a peer status update (including post-write side effects), +// labeled by operation. +func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) { + metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6, + metric.WithAttributes(attribute.String("operation", string(op))), + ) +} diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index 1fd78bc3a..fd9087a96 100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -29,6 +29,7 @@ type MockAppMetrics struct { StoreMetricsFunc func() *StoreMetrics UpdateChannelMetricsFunc func() *UpdateChannelMetrics AddAccountManagerMetricsFunc func() *AccountManagerMetrics + EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics } // GetMeter mocks the GetMeter function of the AppMetrics interface @@ -103,6 +104,14 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics { return nil } +// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface +func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics { + if mock.EphemeralPeersMetricsFunc != nil { + return mock.EphemeralPeersMetricsFunc() + } + return nil +} + // AppMetrics is metrics interface type AppMetrics interface { GetMeter() metric2.Meter @@ -114,6 +123,7 @@ type AppMetrics interface { StoreMetrics() *StoreMetrics UpdateChannelMetrics() *UpdateChannelMetrics AccountManagerMetrics() *AccountManagerMetrics + EphemeralPeersMetrics() *EphemeralPeersMetrics } // defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/ @@ -129,6 +139,7 @@ type defaultAppMetrics struct { storeMetrics *StoreMetrics updateChannelMetrics *UpdateChannelMetrics accountManagerMetrics *AccountManagerMetrics + ephemeralMetrics *EphemeralPeersMetrics } // IDPMetrics returns metrics for the idp package @@ -161,6 +172,11 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr return appMetrics.accountManagerMetrics } +// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop +func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics { + return appMetrics.ephemeralMetrics +} + // Close stop application metrics HTTP handler and closes listener. func (appMetrics *defaultAppMetrics) Close() error { if appMetrics.listener == nil { @@ -245,6 +261,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err) } + ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter) + if err != nil { + return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err) + } + return &defaultAppMetrics{ Meter: meter, ctx: ctx, @@ -254,6 +275,7 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { storeMetrics: storeMetrics, updateChannelMetrics: updateChannelMetrics, accountManagerMetrics: accountManagerMetrics, + ephemeralMetrics: ephemeralMetrics, }, nil } @@ -290,6 +312,11 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err) } + ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter) + if err != nil { + return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err) + } + return &defaultAppMetrics{ Meter: meter, ctx: ctx, @@ -300,5 +327,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric storeMetrics: storeMetrics, updateChannelMetrics: updateChannelMetrics, accountManagerMetrics: accountManagerMetrics, + ephemeralMetrics: ephemeralMetrics, }, nil } diff --git a/management/server/telemetry/ephemeral_metrics.go b/management/server/telemetry/ephemeral_metrics.go new file mode 100644 index 000000000..a7fb432f8 --- /dev/null +++ b/management/server/telemetry/ephemeral_metrics.go @@ -0,0 +1,115 @@ +package telemetry + +import ( + "context" + + "go.opentelemetry.io/otel/metric" +) + +// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how +// many peers are currently scheduled for deletion, how many tick runs +// the cleaner has performed, how many peers it has removed, and how +// many delete batches failed. +type EphemeralPeersMetrics struct { + ctx context.Context + + pending metric.Int64UpDownCounter + cleanupRuns metric.Int64Counter + peersCleaned metric.Int64Counter + errors metric.Int64Counter +} + +// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters. +func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) { + pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up")) + if err != nil { + return nil, err + } + + cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer")) + if err != nil { + return nil, err + } + + peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter", + metric.WithUnit("1"), + metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop")) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete")) + if err != nil { + return nil, err + } + + return &EphemeralPeersMetrics{ + ctx: ctx, + pending: pending, + cleanupRuns: cleanupRuns, + peersCleaned: peersCleaned, + errors: errors, + }, nil +} + +// All methods are nil-receiver safe so callers that haven't wired metrics +// (tests, self-hosted with metrics off) can invoke them unconditionally. + +// IncPending bumps the pending gauge when a peer is added to the cleanup list. +func (m *EphemeralPeersMetrics) IncPending() { + if m == nil { + return + } + m.pending.Add(m.ctx, 1) +} + +// AddPending bumps the pending gauge by n โ€” used at startup when the +// initial set of ephemeral peers is loaded from the store. +func (m *EphemeralPeersMetrics) AddPending(n int64) { + if m == nil || n <= 0 { + return + } + m.pending.Add(m.ctx, n) +} + +// DecPending decreases the pending gauge โ€” used both when a peer reconnects +// before its deadline (removed from the list) and when a cleanup tick +// actually deletes it. +func (m *EphemeralPeersMetrics) DecPending(n int64) { + if m == nil || n <= 0 { + return + } + m.pending.Add(m.ctx, -n) +} + +// CountCleanupRun records one cleanup pass that processed >0 peers. Idle +// ticks (nothing to do) deliberately don't increment so the rate +// reflects useful work. +func (m *EphemeralPeersMetrics) CountCleanupRun() { + if m == nil { + return + } + m.cleanupRuns.Add(m.ctx, 1) +} + +// CountPeersCleaned records the number of peers a single tick deleted. +func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) { + if m == nil || n <= 0 { + return + } + m.peersCleaned.Add(m.ctx, n) +} + +// CountCleanupError records a failed delete batch. +func (m *EphemeralPeersMetrics) CountCleanupError() { + if m == nil { + return + } + m.errors.Add(m.ctx, 1) +} From 80966ab1b09bd86b7a526d9402b6a47438bc0943 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 20 May 2026 08:25:30 +0200 Subject: [PATCH 15/31] [management] Ensure SessionStartedAt has a default value (#6211) * [management] Ensure SessionStartedAt has a default value Avoid null values for the new column * [management] Add PeerStatus with LastSeen in peer_test * [management] Add migration for PeerStatusSessionStartedAt default value * [management] Add PeerStatus with LastSeen in migration tests --- management/server/migration/migration_test.go | 6 +++++- management/server/peer/peer.go | 2 +- management/server/peer_test.go | 3 +++ management/server/store/store.go | 3 +++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 5e00976c2..cc97c2dff 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -198,7 +198,11 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { require.NoError(t, err, "Failed to insert account") account.PeersG = []nbpeer.Peer{ - {AccountID: "1234", Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, + { + AccountID: "1234", + Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now()}, + }, } err = db.Save(account).Error diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 2963dfcbd..6294d1c0a 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -86,7 +86,7 @@ type PeerStatus struct { //nolint:revive // active session". Integer nanoseconds are used so equality is // precision-safe across drivers, and so the predicates compose to a // single bigint comparison. - SessionStartedAt int64 + SessionStartedAt int64 `gorm:"not null;default:0"` // Connected indicates whether peer is connected to the management service or not Connected bool // LoginExpired diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 07acf865f..9d6856740 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2218,6 +2218,9 @@ func Test_IsUniqueConstraintError(t *testing.T) { ID: "test-peer-id", AccountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", DNSLabel: "test-peer-dns-label", + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now(), + }, } for _, tt := range tests { diff --git a/management/server/store/store.go b/management/server/store/store.go index a723c1fc3..045f1576a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -471,6 +471,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNewField[types.User](ctx, db, "email", "") }, + func(db *gorm.DB) error { + return migration.MigrateNewField[nbpeer.Peer](ctx, db, "peer_status_session_started_at", int64(0)) + }, func(db *gorm.DB) error { return migration.RemoveDuplicatePeerKeys(ctx, db) }, From d250f92c435bac83fd55f00fad3ee2c292eee910 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 20 May 2026 10:08:34 +0200 Subject: [PATCH 16/31] feat(reverse-proxy): clusters API surfaces type, online status, and capability flags (#6148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The cluster listing now answers three questions in one round-trip instead of forcing the dashboard to cross-reference the domains API: which clusters can this account see, are they currently up, and what do they support. The ProxyCluster wire type drops the boolean self_hosted in favour of a `type` enum (`account` / `shared`) plus explicit `online`, `supports_custom_ports`, `require_subdomain`, and `supports_crowdsec` fields. Store query reworked so offline clusters still appear (no last_seen WHERE), with online and connected_proxies both derived from the existing 2-min active window via portable CASE expressions; the 1-hour heartbeat reaper still removes long-stale rows. Service manager enriches each cluster with the capability flags via the existing per-cluster lookups (CapabilityProvider now also exposes ClusterSupportsCrowdSec). GetActiveClusterAddresses* keep their tight 2-min filter so service routing and domain enumeration aren't pulled into the wider window. The hard cut removes self_hosted from the response โ€” the dashboard is the only consumer and is updated in the matching PR; no transitional field is shipped. Adds a cross-engine regression test asserting offline clusters surface, connected_proxies counts only fresh proxies, and account-scoped BYOP clusters never leak across accounts. --- .../reverseproxy/proxy/manager/manager.go | 2 +- .../proxy/manager/manager_test.go | 2 +- .../modules/reverseproxy/proxy/proxy.go | 27 ++++- .../modules/reverseproxy/service/interface.go | 2 +- .../reverseproxy/service/interface_mock.go | 72 ++++++------ .../reverseproxy/service/manager/api.go | 14 ++- .../reverseproxy/service/manager/manager.go | 22 +++- .../shared/grpc/proxy_group_access_test.go | 2 +- .../shared/grpc/validate_session_test.go | 2 +- .../proxy/auth_callback_integration_test.go | 2 +- management/server/store/sql_store.go | 64 ++++++++-- .../store/sql_store_proxy_clusters_test.go | 109 ++++++++++++++++++ management/server/store/store.go | 2 +- management/server/store/store_mock.go | 86 +++++++------- proxy/management_integration_test.go | 2 +- shared/management/http/api/openapi.yml | 32 ++++- shared/management/http/api/types.gen.go | 73 +++++++++--- 17 files changed, 393 insertions(+), 122 deletions(-) create mode 100644 management/server/store/sql_store_proxy_clusters_test.go diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index b72a6ebe5..510500e0c 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -17,7 +17,7 @@ type store interface { UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) - GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) + GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go index 3c53fe684..3436216b4 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context } return nil, nil } -func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { +func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { return nil, nil } func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 64394799e..9da7910df 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -42,10 +42,35 @@ func (Proxy) TableName() string { return "proxies" } +// ClusterType is the source of a proxy cluster. +type ClusterType string + +const ( + // ClusterTypeAccount is a cluster operated by the account itself (BYOP) โ€” + // at least one proxy row in the cluster carries a non-NULL account_id. + ClusterTypeAccount ClusterType = "account" + // ClusterTypeShared is a cluster operated by NetBird and shared across + // accounts โ€” all proxy rows in the cluster have account_id IS NULL. + ClusterTypeShared ClusterType = "shared" +) + // Cluster represents a group of proxy nodes serving the same address. +// +// Online and ConnectedProxies derive from the same 2-min active window +// the rest of the module uses, but Cluster rows are not gated on it โ€” +// the cluster listing surfaces offline clusters too so operators can +// see and clean them up. The 1-hour heartbeat reaper still bounds the +// table eventually. type Cluster struct { ID string Address string + Type ClusterType + Online bool ConnectedProxies int - SelfHosted bool + // Capability flags. *bool because nil means "no proxy reported a + // capability for this cluster" โ€” the dashboard renders these as + // unknown rather than false. + SupportsCustomPorts *bool + RequireSubdomain *bool + SupportsCrowdSec *bool } diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index 6a94aa32b..dddf6ae8a 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -9,7 +9,7 @@ import ( ) type Manager interface { - GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) + GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 83b2162ed..24963fe30 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req) } -// DeleteAllServices mocks base method. -func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteAllServices indicates an expected call of DeleteAllServices. -func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) -} - // DeleteAccountCluster mocks base method. func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error { m.ctrl.T.Helper() @@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress) } +// DeleteAllServices mocks base method. +func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllServices indicates an expected call of DeleteAllServices. +func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID) } -// GetActiveClusters mocks base method. -func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID) - ret0, _ := ret[0].([]proxy.Cluster) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetActiveClusters indicates an expected call of GetActiveClusters. -func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID) -} - // GetAllServices mocks base method. func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) { m.ctrl.T.Helper() @@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) } -// GetServiceByDomain mocks base method. -func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { +// GetClusters mocks base method. +func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) - ret0, _ := ret[0].(*Service) + ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID) + ret0, _ := ret[0].([]proxy.Cluster) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetServiceByDomain indicates an expected call of GetServiceByDomain. -func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { +// GetClusters indicates an expected call of GetClusters. +func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID) } // GetGlobalServices mocks base method. @@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID) } +// GetServiceByDomain mocks base method. +func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain) + ret0, _ := ret[0].(*Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceByDomain indicates an expected call of GetServiceByDomain. +func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain) +} + // GetServiceByID mocks base method. func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index 08272077c..9d93d52ee 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { return } - clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) + clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -196,10 +196,14 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { apiClusters := make([]api.ProxyCluster, 0, len(clusters)) for _, c := range clusters { apiClusters = append(apiClusters, api.ProxyCluster{ - Id: c.ID, - Address: c.Address, - ConnectedProxies: c.ConnectedProxies, - SelfHosted: c.SelfHosted, + Id: c.ID, + Address: c.Address, + Type: api.ProxyClusterType(c.Type), + Online: c.Online, + ConnectedProxies: c.ConnectedProxies, + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + SupportsCrowdsec: c.SupportsCrowdSec, }) } diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 4a8598afb..ca0c5540f 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -81,6 +81,7 @@ type ClusterDeriver interface { type CapabilityProvider interface { ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool + ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool } type Manager struct { @@ -112,8 +113,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) { m.exposeReaper.StartExposeReaper(ctx) } -// GetActiveClusters returns all active proxy clusters with their connected proxy count. -func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { +// GetClusters returns every proxy cluster visible to the account +// (shared + its own BYOP), regardless of whether any proxy in the +// cluster is currently heartbeating. Each cluster is enriched with the +// capability flags reported by its active proxies so the dashboard can +// render feature support without a second round-trip. +func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -122,7 +127,18 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - return m.store.GetActiveProxyClusters(ctx, accountID) + clusters, err := m.store.GetProxyClusters(ctx, accountID) + if err != nil { + return nil, err + } + + for i := range clusters { + clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address) + clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address) + clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address) + } + + return clusters, nil } // DeleteAccountCluster removes all proxy registrations for the given cluster address diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 46dad5b56..5980f8a30 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s return nil, errors.New("service not found for domain: " + domain) } -func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { +func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 7b7ffcfb2..774c5d1d3 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -322,7 +322,7 @@ func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Conte return m.store.GetServiceByDomain(ctx, domain) } -func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { +func (m *testValidateSessionServiceManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { return nil, nil } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 30d8aa0e7..f08d5daf1 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -444,7 +444,7 @@ func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain stri return m.store.GetServiceByDomain(ctx, domain) } -func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { +func (m *testServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8cf37de56..f3c6b741b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5736,19 +5736,67 @@ func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, acc return nil } -func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { - var clusters []proxy.Cluster +// GetProxyClusters returns every cluster the account can see (shared +// plus its own BYOP), regardless of whether any proxy in the cluster +// is currently heartbeating. Online and ConnectedProxies are derived +// from the 2-min active window so the dashboard can render offline +// clusters distinctly; the 1-hour heartbeat reaper still removes rows +// that go quiet for too long. +// +// AccountOwned is determined by whether any proxy row in the group +// carries a non-NULL account_id; the caller maps that to Cluster.Type. +// Capability flags are NOT filled here โ€” the handler enriches them via +// the per-cluster capability lookups. +func (s *SqlStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { + activeCutoff := time.Now().Add(-proxyActiveThreshold) + type clusterRow struct { + ID string + Address string + ConnectedProxies int + Online bool + AccountOwned bool + } + + var rows []clusterRow result := s.db.Model(&proxy.Proxy{}). - Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted"). - Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)", - proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID). + Select( + "MIN(id) AS id, "+ + "cluster_address AS address, "+ + // COUNT(CASE WHEN ... THEN 1 END) counts only non-NULL โ€” i.e. only + // rows that satisfy the predicate โ€” so it works portably across + // sqlite/postgres/mysql without dialect-specific FILTER syntax. + "COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS connected_proxies, "+ + // MAX(CASE โ€ฆ) > 0 expresses BOOL_OR in a way Postgres tolerates + // (Postgres can't MAX a boolean column). + "MAX(CASE WHEN status = ? AND last_seen > ? THEN 1 ELSE 0 END) > 0 AS online, "+ + "MAX(CASE WHEN account_id IS NOT NULL THEN 1 ELSE 0 END) > 0 AS account_owned", + proxy.StatusConnected, activeCutoff, + proxy.StatusConnected, activeCutoff, + ). + Where("account_id IS NULL OR account_id = ?", accountID). Group("cluster_address"). - Scan(&clusters) + Scan(&rows) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error) - return nil, status.Errorf(status.Internal, "get active proxy clusters") + log.WithContext(ctx).Errorf("failed to get proxy clusters: %v", result.Error) + return nil, status.Errorf(status.Internal, "get proxy clusters") + } + + clusters := make([]proxy.Cluster, 0, len(rows)) + for _, r := range rows { + c := proxy.Cluster{ + ID: r.ID, + Address: r.Address, + Online: r.Online, + ConnectedProxies: r.ConnectedProxies, + } + if r.AccountOwned { + c.Type = proxy.ClusterTypeAccount + } else { + c.Type = proxy.ClusterTypeShared + } + clusters = append(clusters, c) } return clusters, nil diff --git a/management/server/store/sql_store_proxy_clusters_test.go b/management/server/store/sql_store_proxy_clusters_test.go new file mode 100644 index 000000000..cdacfedae --- /dev/null +++ b/management/server/store/sql_store_proxy_clusters_test.go @@ -0,0 +1,109 @@ +package store + +import ( + "context" + "os" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +// TestSqlStore_GetProxyClusters_DerivesOnlineAndType guards the +// account-visible cluster list against silent regressions in two +// dimensions: +// +// 1. Online derivation: a cluster with one stale and one fresh proxy +// is online and counts only the fresh proxy; a cluster whose +// proxies all heartbeated outside the 2-min window appears offline +// with connected_proxies = 0 (rather than disappearing, which is +// what the old query did). +// 2. Type derivation: a cluster scoped to the calling account is +// reported as `account`; a cluster with account_id IS NULL is +// reported as `shared`. Clusters scoped to other accounts must not +// leak into the result. +// +// Capability flags are intentionally not asserted here โ€” they're filled +// by the manager (handler) layer from the per-cluster capability +// lookups, not by the store query. +func TestSqlStore_GetProxyClusters_DerivesOnlineAndType(t *testing.T) { + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") + } + + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + ctx := context.Background() + accountID := "acct-clusters" + require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", ""))) + + otherAccountID := "acct-other" + require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, otherAccountID, "user-2", ""))) + + acctID := accountID + otherID := otherAccountID + + fresh := time.Now().Add(-30 * time.Second) + stale := time.Now().Add(-30 * time.Minute) + + mustSave := func(id, cluster string, accID *string, status string, lastSeen time.Time) { + require.NoError(t, store.SaveProxy(ctx, &rpproxy.Proxy{ + ID: id, + SessionID: id + "-sess", + ClusterAddress: cluster, + IPAddress: "10.0.0.1", + AccountID: accID, + LastSeen: lastSeen, + Status: status, + })) + } + + // shared-mixed: one fresh + one stale proxy โ†’ online, connected=1 + mustSave("p-shared-fresh", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, fresh) + mustSave("p-shared-stale", "shared-mixed.netbird.io", nil, rpproxy.StatusConnected, stale) + + // shared-offline: only stale proxies โ†’ offline, connected=0, + // but row must still appear (this is the new semantic โ€” old + // query would have dropped it entirely). + mustSave("p-shared-off", "shared-offline.netbird.io", nil, rpproxy.StatusConnected, stale) + + // account-online: BYOP cluster owned by acctID, fresh + mustSave("p-acct-fresh", "byop.acct.example", &acctID, rpproxy.StatusConnected, fresh) + + // other-account: must not surface for acctID + mustSave("p-other", "byop.other.example", &otherID, rpproxy.StatusConnected, fresh) + + clusters, err := store.GetProxyClusters(ctx, accountID) + require.NoError(t, err) + + byAddr := map[string]rpproxy.Cluster{} + for _, c := range clusters { + byAddr[c.Address] = c + } + + assert.NotContains(t, byAddr, "byop.other.example", + "another account's BYOP cluster must not leak into this account's listing") + + require.Contains(t, byAddr, "shared-mixed.netbird.io") + mixed := byAddr["shared-mixed.netbird.io"] + assert.Equal(t, rpproxy.ClusterTypeShared, mixed.Type, "shared cluster (account_id IS NULL) must be reported as Type=shared") + assert.True(t, mixed.Online, "cluster with a fresh proxy must be online") + assert.Equal(t, 1, mixed.ConnectedProxies, "connected_proxies must count only fresh proxies; the stale one should not bump the count") + + require.Contains(t, byAddr, "shared-offline.netbird.io", + "offline clusters must still appear so the dashboard can render them โ€” the old GetActiveProxyClusters would have dropped this row, which is the regression this test guards against") + offline := byAddr["shared-offline.netbird.io"] + assert.Equal(t, rpproxy.ClusterTypeShared, offline.Type) + assert.False(t, offline.Online, "no fresh heartbeat โ†’ offline") + assert.Equal(t, 0, offline.ConnectedProxies, "no fresh proxies โ†’ connected_proxies=0") + + require.Contains(t, byAddr, "byop.acct.example") + acct := byAddr["byop.acct.example"] + assert.Equal(t, rpproxy.ClusterTypeAccount, acct.Type, "BYOP cluster owned by the account must be reported as Type=account") + assert.True(t, acct.Online) + assert.Equal(t, 1, acct.ConnectedProxies) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 045f1576a..42cdcf36d 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -307,7 +307,7 @@ type Store interface { UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) - GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) + GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index d51629606..4f9d875d2 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -380,6 +380,20 @@ func (mr *MockStoreMockRecorder) DeleteAccount(ctx, account interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccount", reflect.TypeOf((*MockStore)(nil).DeleteAccount), ctx, account) } +// DeleteAccountCluster mocks base method. +func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. +func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) +} + // DeleteCustomDomain mocks base method. func (m *MockStore) DeleteCustomDomain(ctx context.Context, accountID, domainID string) error { m.ctrl.T.Helper() @@ -577,20 +591,6 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID) } -// DeleteAccountCluster mocks base method. -func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteAccountCluster indicates an expected call of DeleteAccountCluster. -func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID) -} - // DeleteRoute mocks base method. func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { m.ctrl.T.Helper() @@ -731,6 +731,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID) +} + // EphemeralServiceExists mocks base method. func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { m.ctrl.T.Helper() @@ -1332,21 +1346,6 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID) } -// GetActiveProxyClusters mocks base method. -func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID) - ret0, _ := ret[0].([]proxy.Cluster) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. -func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID) -} - // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() @@ -2048,6 +2047,21 @@ func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID) } +// GetProxyClusters mocks base method. +func (m *MockStore) GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyClusters", ctx, accountID) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyClusters indicates an expected call of GetProxyClusters. +func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID) +} + // GetResourceGroups mocks base method. func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() @@ -2950,20 +2964,6 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) } -// DisconnectProxy mocks base method. -func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DisconnectProxy indicates an expected call of DisconnectProxy. -func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID) -} - // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 9fd3d2ce9..d7e891801 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -366,7 +366,7 @@ func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, doma return m.store.GetServiceByDomain(ctx, domain) } -func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { +func (m *storeBackedServiceManager) GetClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { return nil, nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 942f3aa45..353aff72d 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3417,19 +3417,43 @@ components: type: string description: Cluster address used for CNAME targets example: "eu.proxy.netbird.io" + type: + $ref: '#/components/schemas/ProxyClusterType' + online: + type: boolean + description: Whether at least one proxy in the cluster has heartbeated within the active window + example: true connected_proxies: type: integer - description: Number of proxy nodes connected in this cluster + description: Number of proxy nodes currently connected (heartbeat within the active window) example: 3 - self_hosted: + supports_custom_ports: type: boolean - description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner + description: Whether the cluster supports binding arbitrary TCP/UDP ports + example: true + require_subdomain: + type: boolean + description: Whether services on this cluster must include a subdomain label + example: false + supports_crowdsec: + type: boolean + description: Whether all active proxies in the cluster have CrowdSec configured example: false required: - id - address + - type + - online - connected_proxies - - self_hosted + ProxyClusterType: + type: string + description: | + Source of the proxy cluster. `account` clusters are owned and operated by the account (BYOP); + `shared` clusters are operated by NetBird and shared across accounts. + enum: + - account + - shared + example: shared ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index b3bb475a9..16e765f8c 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1,6 +1,6 @@ // Package api provides primitives to interact with the openapi HTTP API. // -// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT. +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.7.0 DO NOT EDIT. package api import ( @@ -13,8 +13,8 @@ import ( ) const ( - BearerAuthScopes = "BearerAuth.Scopes" - TokenAuthScopes = "TokenAuth.Scopes" + BearerAuthScopes bearerAuthContextKey = "BearerAuth.Scopes" + TokenAuthScopes tokenAuthContextKey = "TokenAuth.Scopes" ) // Defines values for AccessRestrictionsCrowdsecMode. @@ -511,6 +511,7 @@ func (e GroupMinimumIssued) Valid() bool { // Defines values for IdentityProviderType. const ( + IdentityProviderTypeAdfs IdentityProviderType = "adfs" IdentityProviderTypeEntra IdentityProviderType = "entra" IdentityProviderTypeGoogle IdentityProviderType = "google" IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft" @@ -518,12 +519,13 @@ const ( IdentityProviderTypeOkta IdentityProviderType = "okta" IdentityProviderTypePocketid IdentityProviderType = "pocketid" IdentityProviderTypeZitadel IdentityProviderType = "zitadel" - IdentityProviderTypeAdfs IdentityProviderType = "adfs" ) // Valid indicates whether the value is a known member of the IdentityProviderType enum. func (e IdentityProviderType) Valid() bool { switch e { + case IdentityProviderTypeAdfs: + return true case IdentityProviderTypeEntra: return true case IdentityProviderTypeGoogle: @@ -538,8 +540,6 @@ func (e IdentityProviderType) Valid() bool { return true case IdentityProviderTypeZitadel: return true - case IdentityProviderTypeAdfs: - return true default: return false } @@ -878,6 +878,24 @@ func (e PolicyRuleUpdateProtocol) Valid() bool { } } +// Defines values for ProxyClusterType. +const ( + ProxyClusterTypeAccount ProxyClusterType = "account" + ProxyClusterTypeShared ProxyClusterType = "shared" +) + +// Valid indicates whether the value is a known member of the ProxyClusterType enum. +func (e ProxyClusterType) Valid() bool { + switch e { + case ProxyClusterTypeAccount: + return true + case ProxyClusterTypeShared: + return true + default: + return false + } +} + // Defines values for ResourceType. const ( ResourceTypeDomain ResourceType = "domain" @@ -1638,7 +1656,9 @@ type Checks struct { // OsVersionCheck Posture check for the version of operating system OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` - // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. + // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it + // contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, + // so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` // ProcessCheck Posture Check for binaries exist and are running in the peerโ€™s system @@ -3330,7 +3350,9 @@ type PeerMinimum struct { Name string `json:"name"` } -// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. +// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it +// contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, +// so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. type PeerNetworkRangeCheck struct { // Action Action to take upon policy match Action PeerNetworkRangeCheckAction `json:"action"` @@ -3785,19 +3807,36 @@ type ProxyAccessLogsResponse struct { // ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address type ProxyCluster struct { - // Id Unique identifier of a proxy in this cluster - Id string `json:"id"` - // Address Cluster address used for CNAME targets Address string `json:"address"` - // ConnectedProxies Number of proxy nodes connected in this cluster + // ConnectedProxies Number of proxy nodes currently connected (heartbeat within the active window) ConnectedProxies int `json:"connected_proxies"` - // SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner - SelfHosted bool `json:"self_hosted"` + // Id Unique identifier of a proxy in this cluster + Id string `json:"id"` + + // Online Whether at least one proxy in the cluster has heartbeated within the active window + Online bool `json:"online"` + + // RequireSubdomain Whether services on this cluster must include a subdomain label + RequireSubdomain *bool `json:"require_subdomain,omitempty"` + + // SupportsCrowdsec Whether all active proxies in the cluster have CrowdSec configured + SupportsCrowdsec *bool `json:"supports_crowdsec,omitempty"` + + // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports + SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + + // Type Source of the proxy cluster. `account` clusters are owned and operated by the account (BYOP); + // `shared` clusters are operated by NetBird and shared across accounts. + Type ProxyClusterType `json:"type"` } +// ProxyClusterType Source of the proxy cluster. `account` clusters are owned and operated by the account (BYOP); +// `shared` clusters are operated by NetBird and shared across accounts. +type ProxyClusterType string + // ProxyToken defines model for ProxyToken. type ProxyToken struct { CreatedAt time.Time `json:"created_at"` @@ -4820,6 +4859,12 @@ type ZoneRequest struct { // Conflict Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided. type Conflict = ErrorResponse +// bearerAuthContextKey is the context key for BearerAuth security scheme +type bearerAuthContextKey string + +// tokenAuthContextKey is the context key for TokenAuth security scheme +type tokenAuthContextKey string + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number From c784b0255063b9cbfde830c78670de2400e46c1c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 20 May 2026 12:21:03 +0200 Subject: [PATCH 17/31] [misc] Update contribution guidelines (#6219) Update contribution guidelines and PR template to require discussing impactful changes with the team --- .github/pull_request_template.md | 1 + CONTRIBUTING.md | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 9d6bc96eb..8e68054bd 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,6 +12,7 @@ - [ ] Is a feature enhancement - [ ] It is a refactor - [ ] Created tests that fail without the change (if possible) +- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature โ€” **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first). > By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 960cd30e9..cd1c087bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/ - [Contributing to NetBird](#contributing-to-netbird) - [Contents](#contents) - [Code of conduct](#code-of-conduct) + - [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first) - [Directory structure](#directory-structure) - [Development setup](#development-setup) - [Requirements](#requirements) @@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to community@netbird.io. +## Discuss changes with the NetBird team first + +Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation. + +Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design. + +Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion โ€” open the PR directly. + ## Directory structure The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages. From 9192b4f029f8f0eeaad77fff8625c34ba9849668 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 20 May 2026 20:09:22 +0900 Subject: [PATCH 18/31] [client] Bump macOS sleep callback timeout to 20s (#6220) --- client/internal/sleep/detector_darwin.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/internal/sleep/detector_darwin.go b/client/internal/sleep/detector_darwin.go index ef495bded..fc4713b21 100644 --- a/client/internal/sleep/detector_darwin.go +++ b/client/internal/sleep/detector_darwin.go @@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do } doneChan := make(chan struct{}) - timeout := time.NewTimer(500 * time.Millisecond) + // macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long + // enough for teardown to finish while staying under that deadline. + timeout := time.NewTimer(20 * time.Second) defer timeout.Stop() go func() { From 4955c345d53f63394266305744841c4e1bff8123 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 20 May 2026 23:25:56 +0900 Subject: [PATCH 19/31] Clean up README header, key features table, and self-hosted quickstart (#6178) --- README.md | 153 +++++++++++++++++++++++++----------------------------- 1 file changed, 70 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index dc84af2fd..cc27e2d28 100644 --- a/README.md +++ b/README.md @@ -1,147 +1,134 @@
-
-
-

- -

-

- - - - - - -
+

+ NetBird logo +

+

+ + SonarCloud alert status + + + BSD-3 License + - - + NetBird Slack + - - -
+ Community forum + - - + Gurubase: Ask NetBird Guru +

-

- - Start using NetBird at netbird.io + + Start using NetBird at netbird.io +
+ See Documentation +
+ Join our Slack channel or our Community forum +

- See Documentation
- Join our Slack channel or our Community forum -
- -
-
- - ๐Ÿš€ We are hiring! Join us at careers.netbird.io - -
-
- - New: NetBird terraform provider - + + ๐Ÿš€ We are hiring! Join us at careers.netbird.io +

-
- **NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.** **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. **Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. -### Open Source Network Security in a Single Platform - https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 -### Self-Host NetBird (Video) +### Self-host NetBird (video) + [![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ) ### Key features -| Connectivity | Management | Security | Automation| Platforms | -|----|----|----|----|----| -|
  • - \[x] Kernel WireGuard
|
  • - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)
|
  • - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)
|
  • - \[x] [Public API](https://docs.netbird.io/api)
|
  • - \[x] Linux
| -|
  • - \[x] Peer-to-peer connections
|
  • - \[x] Auto peer discovery and configuration
  • |
    • - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)
    • |
      • - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)
      • |
        • - \[x] Mac
        • | -|
          • - \[x] Connection relay fallback
          • |
            • - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)
            • |
              • - \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)
              • |
                • - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)
                • |
                  • - \[x] Windows
                  • | -|
                    • - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)
                    • |
                      • - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)
                      • |
                        • - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)
                        • |
                          • - \[x] IdP groups sync with JWT
                          • |
                            • - \[x] Android
                            • | -|
                              • - \[x] NAT traversal with BPF
                              • |
                                • - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)
                                • |
                                  • - \[x] Peer-to-peer encryption
                                  • ||
                                    • - \[x] iOS
                                    • | -|||
                                      • - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
                                      • ||
                                        • - \[x] OpenWRT
                                        • | -|||
                                          • - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
                                          • ||
                                            • - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
                                            • | -|||||
                                              • - \[x] Docker
                                              • | +| Connectivity | Management | Security | Automation | Platforms | +|---|---|---|---|---| +| โœ“ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | โœ“ [Admin Web UI](https://github.com/netbirdio/dashboard) | โœ“ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | โœ“ [Public API](https://docs.netbird.io/api) | โœ“ [Linux](https://docs.netbird.io/get-started/install/linux) | +| โœ“ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | โœ“ Auto peer discovery and configuration | โœ“ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | โœ“ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | โœ“ [macOS](https://docs.netbird.io/get-started/install/macos) | +| โœ“ Connection relay fallback | โœ“ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | โœ“ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | โœ“ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | โœ“ [Windows](https://docs.netbird.io/get-started/install/windows) | +| โœ“ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | โœ“ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | โœ“ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | โœ“ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | โœ“ [Android](https://docs.netbird.io/get-started/install/android) | +| โœ“ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | โœ“ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | โœ“ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | โœ“ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | โœ“ [Android TV](https://docs.netbird.io/get-started/install/android-tv) | +| โœ“ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | โœ“ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | โœ“ Peer-to-peer encryption | โœ“ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | โœ“ [iOS](https://docs.netbird.io/get-started/install/ios) | +| โœ“ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | โœ“ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | โœ“ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | โœ“ [Apple TV](https://docs.netbird.io/get-started/install/tvos) | +| โœ“ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | โœ“ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | โœ“ FreeBSD | +| โœ“ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | โœ“ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | โœ“ [pfSense](https://docs.netbird.io/get-started/install/pfsense) | +| | | | | โœ“ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) | +| | | | | โœ“ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) | +| | | | | โœ“ OpenWRT | +| | | | | โœ“ [Synology](https://docs.netbird.io/get-started/install/synology) | +| | | | | โœ“ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) | +| | | | | โœ“ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) | +| | | | | โœ“ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) | +| | | | | โœ“ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) | +| | | | | โœ“ [Container](https://docs.netbird.io/get-started/install/docker) | ### Quickstart with NetBird Cloud -- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) -- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address. -- Check NetBird [admin UI](https://app.netbird.io/). -- Add more machines. +- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install). +- Follow the steps to sign up with Google, Microsoft, GitHub or your email address. +- Check the NetBird [admin UI](https://app.netbird.io/). ### Quickstart with self-hosted NetBird -> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. -Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs. +This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs. **Infrastructure requirements:** -- A Linux VM with at least **1CPU** and **2GB** of memory. -- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**. -- **Public domain** name pointing to the VM. +- A Linux VM with at least **1 CPU** and **2 GB** of memory. +- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**. +- A **public domain** name pointing to the VM. **Software requirements:** -- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. -- [jq](https://jqlang.github.io/jq/) installed. In most distributions - Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq` -- [curl](https://curl.se/) installed. - Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl` +- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/). **Steps** - Download and run the installation script: ```bash export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash ``` -- Once finished, you can manage the resources via `docker-compose` ### A bit on NetBird internals -- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard. -- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers). -- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines. -- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. -- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates. -- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. - -[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. +- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard. +- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents. +- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections. +- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. +- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates. +- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.

                                                - + NetBird high-level architecture diagram

                                                See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details. ### Community projects -- [NetBird installer script](https://github.com/physk/netbird-installer) -- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/) -- [netbird-tui](https://github.com/n0pashkov/netbird-tui) โ€” terminal UI for managing NetBird peers, routes, and settings +- [NetBird installer script](https://github.com/physk/netbird-installer) +- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings +- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks **Note**: The `main` branch may be in an *unstable or even broken state* during development. For stable versions, see [releases](https://github.com/netbirdio/netbird/releases). ### Support acknowledgement -In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking. +In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking. ![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png) -### Testimonials -We use open-source technologies like [WireGuardยฎ](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution). +### Acknowledgements +We build on open-source technologies like [WireGuardยฎ](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing). ### Legal -This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/. +This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/. Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory. _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld. From 6137a1fcc53ad0ea0f5048ac97369168f1116a69 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 20 May 2026 18:21:22 +0200 Subject: [PATCH 20/31] [proxy] concurrent proxy snapshot apply (#6207) --- management/internals/shared/grpc/proxy.go | 399 +++++++++---- .../shared/grpc/sync_mappings_test.go | 411 ++++++++++++++ proxy/handle_mapping_stream_test.go | 7 +- proxy/internal/metrics/metrics.go | 61 ++ proxy/internal/roundtrip/netbird.go | 62 ++- proxy/process_mappings_bench_test.go | 300 ++++++++++ proxy/server.go | 278 ++++++++-- proxy/snapshot_reconcile_test.go | 9 +- proxy/sync_mappings_test.go | 525 ++++++++++++++++++ shared/management/proto/proxy_service.pb.go | 487 +++++++++++++--- shared/management/proto/proxy_service.proto | 41 ++ .../management/proto/proxy_service_grpc.pb.go | 82 +++ 12 files changed, 2446 insertions(+), 216 deletions(-) create mode 100644 management/internals/shared/grpc/sync_mappings_test.go create mode 100644 proxy/process_mappings_bench_test.go create mode 100644 proxy/sync_mappings_test.go diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index eada2d86a..4abeb8e7c 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "io" "net" "net/http" "net/url" @@ -136,9 +137,12 @@ type proxyConnection struct { tokenID string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.GetMappingUpdateResponse - ctx context.Context - cancel context.CancelFunc + // syncStream is set when the proxy connected via SyncMappings. + // When non-nil, the sender goroutine uses this instead of stream. + syncStream proto.ProxyService_SyncMappingsServer + sendChan chan *proto.GetMappingUpdateResponse + ctx context.Context + cancel context.CancelFunc } func enforceAccountScope(ctx context.Context, requestAccountID string) error { @@ -206,145 +210,322 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller s.proxyController = proxyController } +// proxyConnectParams holds the validated parameters extracted from either +// a GetMappingUpdateRequest or a SyncMappingsInit message. +type proxyConnectParams struct { + proxyID string + address string + capabilities *proto.ProxyCapabilities +} + // GetMappingUpdate handles the control stream with proxy clients func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { - ctx := stream.Context() + params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context()) + if err != nil { + return err + } + params.capabilities = req.GetCapabilities() - peerInfo := PeerIPFromContext(ctx) - log.Infof("New proxy connection from %s", peerInfo) + conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{ + stream: stream, + }) + if err != nil { + return err + } - proxyID := req.GetProxyId() + if err := s.sendSnapshot(stream.Context(), conn); err != nil { + s.cleanupFailedSnapshot(stream.Context(), conn) + return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + + return s.serveProxyConnection(conn, proxyRecord, errChan, false) +} + +// SyncMappings implements the bidirectional SyncMappings RPC. +// It mirrors GetMappingUpdate but provides application-level back-pressure: +// management waits for an ack from the proxy before sending the next batch. +func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error { + init, err := recvSyncInit(stream) + if err != nil { + return err + } + + params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context()) + if err != nil { + return err + } + params.capabilities = init.GetCapabilities() + + conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{ + syncStream: stream, + }) + if err != nil { + return err + } + + if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil { + s.cleanupFailedSnapshot(stream.Context(), conn) + return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err) + } + + errChan := make(chan error, 2) + go s.sender(conn, errChan) + go s.drainRecv(stream, errChan) + + return s.serveProxyConnection(conn, proxyRecord, errChan, true) +} + +// recvSyncInit receives and validates the first message on a SyncMappings stream. +func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) { + firstMsg, err := stream.Recv() + if err != nil { + return nil, status.Errorf(codes.Internal, "receive init: %v", err) + } + init := firstMsg.GetInit() + if init == nil { + return nil, status.Errorf(codes.InvalidArgument, "first message must be init") + } + return init, nil +} + +// validateProxyConnect validates the proxy ID and address, and checks cluster +// address availability for account-scoped tokens. +func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) { if proxyID == "" { - return status.Errorf(codes.InvalidArgument, "proxy_id is required") + return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required") + } + if !isProxyAddressValid(address) { + return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid") } - proxyAddress := req.GetAddress() - if !isProxyAddressValid(proxyAddress) { - return status.Errorf(codes.InvalidArgument, "proxy address is invalid") - } - - var accountID *string token := GetProxyTokenFromContext(ctx) if token != nil && token.AccountID != nil { - accountID = token.AccountID - - available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID) + available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID) if err != nil { - return status.Errorf(codes.Internal, "check cluster address: %v", err) + return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err) } if !available { - return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress) + return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address) } } + return proxyConnectParams{proxyID: proxyID, address: address}, nil +} + +// registerProxyConnection creates a proxyConnection, registers it with the +// proxy manager and cluster, and stores it in connectedProxies. The caller +// provides a partially initialised connSeed with stream-specific fields set; +// the remaining fields are filled in here. +func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) { + peerInfo := PeerIPFromContext(ctx) + + var accountID *string var tokenID string - if token != nil { + if token := GetProxyTokenFromContext(ctx); token != nil { + if token.AccountID != nil { + accountID = token.AccountID + } tokenID = token.ID } sessionID := uuid.NewString() - - if old, loaded := s.connectedProxies.Load(proxyID); loaded { - oldConn := old.(*proxyConnection) - log.WithFields(log.Fields{ - "proxy_id": proxyID, - "old_session_id": oldConn.sessionID, - "new_session_id": sessionID, - }).Info("Superseding existing proxy connection") - oldConn.cancel() - } + s.supersedePriorConnection(params.proxyID, sessionID) connCtx, cancel := context.WithCancel(ctx) - conn := &proxyConnection{ - proxyID: proxyID, - sessionID: sessionID, - address: proxyAddress, - accountID: accountID, - tokenID: tokenID, - capabilities: req.GetCapabilities(), - stream: stream, - sendChan: make(chan *proto.GetMappingUpdateResponse, 100), - ctx: connCtx, - cancel: cancel, - } + connSeed.proxyID = params.proxyID + connSeed.sessionID = sessionID + connSeed.address = params.address + connSeed.accountID = accountID + connSeed.tokenID = tokenID + connSeed.capabilities = params.capabilities + connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100) + connSeed.ctx = connCtx + connSeed.cancel = cancel var caps *proxy.Capabilities - if c := req.GetCapabilities(); c != nil { + if c := params.capabilities; c != nil { caps = &proxy.Capabilities{ SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, SupportsCrowdsec: c.SupportsCrowdsec, } } - proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps) + + proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps) if err != nil { cancel() if accountID != nil { - return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) + return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err) } - log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - return status.Errorf(codes.Internal, "register proxy in database: %v", err) + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err) + return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err) } - s.connectedProxies.Store(proxyID, conn) - if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + s.connectedProxies.Store(params.proxyID, connSeed) + if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err) } - if err := s.sendSnapshot(ctx, conn); err != nil { - if s.connectedProxies.CompareAndDelete(proxyID, conn) { - if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil { - log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr) - } - } - cancel() - if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil { - log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr) - } - return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err) + return connSeed, proxyRecord, nil +} + +// supersedePriorConnection cancels any existing connection for the given proxy. +func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) { + if old, loaded := s.connectedProxies.Load(proxyID); loaded { + oldConn := old.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "old_session_id": oldConn.sessionID, + "new_session_id": newSessionID, + }).Info("Superseding existing proxy connection") + oldConn.cancel() } +} - errChan := make(chan error, 2) - go s.sender(conn, errChan) +// cleanupFailedSnapshot removes the connection from the cluster and store +// after a snapshot send failure. +func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) { + if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) { + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err) + } + } + conn.cancel() + if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil { + log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err) + } +} - log.WithFields(log.Fields{ - "proxy_id": proxyID, - "session_id": sessionID, - "address": proxyAddress, - "cluster_addr": proxyAddress, - "account_id": accountID, - "total_proxies": len(s.GetConnectedProxies()), - }).Info("Proxy registered in cluster") - defer func() { - if !s.connectedProxies.CompareAndDelete(proxyID, conn) { - log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID) - cancel() +// drainRecv consumes and discards messages from a bidirectional stream. +// The proxy sends an ack for every incremental update; we don't need them +// after the snapshot phase. Recv errors are forwarded to errChan. +func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) { + for { + if _, err := stream.Recv(); err != nil { + errChan <- err return } + } +} - if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { - log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) - } - if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil { - log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) - } +// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender, +// and wait for termination. When bidi is true, normal stream closure (EOF, +// canceled) is treated as a clean disconnect rather than an error. +func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error { + log.WithFields(log.Fields{ + "proxy_id": conn.proxyID, + "session_id": conn.sessionID, + "address": conn.address, + "cluster_addr": conn.address, + "account_id": conn.accountID, + "total_proxies": len(s.GetConnectedProxies()), + }).Info("Proxy registered in cluster") - cancel() - log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) - }() - - go s.heartbeat(connCtx, conn, proxyRecord) + defer s.disconnectProxy(conn) + go s.heartbeat(conn.ctx, conn, proxyRecord) select { case err := <-errChan: - log.WithContext(ctx).Warnf("Failed to send update: %v", err) - return fmt.Errorf("send update to proxy %s: %w", proxyID, err) - case <-connCtx.Done(): - log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID) - return connCtx.Err() + if bidi && isStreamClosed(err) { + log.Infof("Proxy %s stream closed", conn.proxyID) + return nil + } + log.Warnf("Failed to send update: %v", err) + return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err) + case <-conn.ctx.Done(): + log.Infof("Proxy %s context canceled", conn.proxyID) + return conn.ctx.Err() } } +// disconnectProxy removes the connection from cluster and store, unless it +// has already been superseded by a newer connection. +func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) { + if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) { + log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID) + conn.cancel() + return + } + + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err) + } + if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err) + } + + conn.cancel() + log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID) +} + +// sendSnapshotSync sends the initial snapshot with back-pressure: it sends +// one batch, then waits for the proxy to ack before sending the next. +func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error { + if !isProxyAddressValid(conn.address) { + return fmt.Errorf("proxy address is invalid") + } + if s.snapshotBatchSize <= 0 { + return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize) + } + + mappings, err := s.snapshotServiceMappings(ctx, conn) + if err != nil { + return err + } + + for i := 0; i < len(mappings); i += s.snapshotBatchSize { + end := i + s.snapshotBatchSize + if end > len(mappings) { + end = len(mappings) + } + for _, m := range mappings[i:end] { + token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL()) + if err != nil { + return fmt.Errorf("generate auth token for service %s: %w", m.Id, err) + } + m.AuthToken = token + } + if err := stream.Send(&proto.SyncMappingsResponse{ + Mapping: mappings[i:end], + InitialSyncComplete: end == len(mappings), + }); err != nil { + return fmt.Errorf("send snapshot batch: %w", err) + } + + if err := waitForAck(stream); err != nil { + return err + } + } + + if len(mappings) == 0 { + if err := stream.Send(&proto.SyncMappingsResponse{ + InitialSyncComplete: true, + }); err != nil { + return fmt.Errorf("send snapshot completion: %w", err) + } + + if err := waitForAck(stream); err != nil { + return err + } + } + + return nil +} + +func waitForAck(stream proto.ProxyService_SyncMappingsServer) error { + msg, err := stream.Recv() + if err != nil { + return fmt.Errorf("receive ack: %w", err) + } + if msg.GetAck() == nil { + return fmt.Errorf("expected ack, got %T", msg.GetMsg()) + } + return nil +} + // heartbeat updates the proxy's last_seen timestamp every minute and // disconnects the proxy if its access token has been revoked. func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) { @@ -381,6 +562,9 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if !isProxyAddressValid(conn.address) { return fmt.Errorf("proxy address is invalid") } + if s.snapshotBatchSize <= 0 { + return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize) + } mappings, err := s.snapshotServiceMappings(ctx, conn) if err != nil { @@ -460,12 +644,26 @@ func isProxyAddressValid(addr string) bool { return err == nil } -// sender handles sending messages to proxy +// isStreamClosed returns true for errors that indicate normal stream +// termination: io.EOF, context cancellation, or gRPC Canceled. +func isStreamClosed(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + return true + } + return status.Code(err) == codes.Canceled +} + +// sender handles sending messages to proxy. +// When conn.syncStream is set the message is sent as SyncMappingsResponse; +// otherwise the legacy GetMappingUpdateResponse stream is used. func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { select { case resp := <-conn.sendChan: - if err := conn.stream.Send(resp); err != nil { + if err := conn.sendResponse(resp); err != nil { errChan <- err return } @@ -475,6 +673,17 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) } } +// sendResponse sends a mapping update on whichever stream the proxy connected with. +func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error { + if conn.syncStream != nil { + return conn.syncStream.Send(&proto.SyncMappingsResponse{ + Mapping: resp.Mapping, + InitialSyncComplete: resp.InitialSyncComplete, + }) + } + return conn.stream.Send(resp) +} + // SendAccessLog processes access log from proxy func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { accessLog := req.GetLog() @@ -541,8 +750,8 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes return true } connUpdate = &proto.GetMappingUpdateResponse{ - Mapping: filtered, - InitialSyncComplete: update.InitialSyncComplete, + Mapping: filtered, + InitialSyncComplete: update.InitialSyncComplete, } } resp := s.perProxyMessage(connUpdate, conn.proxyID) diff --git a/management/internals/shared/grpc/sync_mappings_test.go b/management/internals/shared/grpc/sync_mappings_test.go new file mode 100644 index 000000000..97f6183bb --- /dev/null +++ b/management/internals/shared/grpc/sync_mappings_test.go @@ -0,0 +1,411 @@ +package grpc + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// syncRecordingStream is a mock ProxyService_SyncMappingsServer that records +// sent messages and returns pre-loaded ack responses from Recv. +type syncRecordingStream struct { + grpc.ServerStream + + mu sync.Mutex + sent []*proto.SyncMappingsResponse + recvMsgs []*proto.SyncMappingsRequest + recvIdx int +} + +func (s *syncRecordingStream) Send(m *proto.SyncMappingsResponse) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sent = append(s.sent, m) + return nil +} + +func (s *syncRecordingStream) Recv() (*proto.SyncMappingsRequest, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.recvIdx >= len(s.recvMsgs) { + return nil, fmt.Errorf("no more recv messages") + } + msg := s.recvMsgs[s.recvIdx] + s.recvIdx++ + return msg, nil +} + +func (s *syncRecordingStream) Context() context.Context { return context.Background() } +func (s *syncRecordingStream) SetHeader(metadata.MD) error { return nil } +func (s *syncRecordingStream) SendHeader(metadata.MD) error { return nil } +func (s *syncRecordingStream) SetTrailer(metadata.MD) {} +func (s *syncRecordingStream) SendMsg(any) error { return nil } +func (s *syncRecordingStream) RecvMsg(any) error { return nil } + +func ackMsg() *proto.SyncMappingsRequest { + return &proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + } +} + +func TestSendSnapshotSync_BatchesWithAcks(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 3 + const totalServices = 7 // 3 + 3 + 1 โ†’ 3 batches, 3 acks (one per batch, including final) + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg(), ackMsg(), ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 3, "should send ceil(7/3) = 3 batches") + + assert.Len(t, stream.sent[0].Mapping, 3) + assert.False(t, stream.sent[0].InitialSyncComplete) + + assert.Len(t, stream.sent[1].Mapping, 3) + assert.False(t, stream.sent[1].InitialSyncComplete) + + assert.Len(t, stream.sent[2].Mapping, 1) + assert.True(t, stream.sent[2].InitialSyncComplete) + + // All 3 acks consumed โ€” including the final batch. + assert.Equal(t, 3, stream.recvIdx) +} + +func TestSendSnapshotSync_SingleBatchWaitsForAck(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 100 + const totalServices = 5 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 1) + assert.Len(t, stream.sent[0].Mapping, totalServices) + assert.True(t, stream.sent[0].InitialSyncComplete) + assert.Equal(t, 1, stream.recvIdx, "final batch ack must be consumed") +} + +func TestSendSnapshotSync_EmptySnapshot(t *testing.T) { + const cluster = "cluster.example.com" + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil) + + s := newSnapshotTestServer(t, 500) + s.serviceManager = mgr + + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ackMsg()}, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + require.Len(t, stream.sent, 1, "empty snapshot must still send sync-complete") + assert.Empty(t, stream.sent[0].Mapping) + assert.True(t, stream.sent[0].InitialSyncComplete) + assert.Equal(t, 1, stream.recvIdx, "empty snapshot ack must be consumed") +} + +func TestSendSnapshotSync_MissingAckReturnsError(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 // 2 batches โ†’ 1 ack needed, but we provide none + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + // No acks available โ€” Recv will return error. + stream := &syncRecordingStream{} + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.Error(t, err) + assert.Contains(t, err.Error(), "receive ack") + // First batch should have been sent before the error. + require.Len(t, stream.sent, 1) +} + +func TestSendSnapshotSync_WrongMessageInsteadOfAck(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + // Send an init message instead of an ack. + stream := &syncRecordingStream{ + recvMsgs: []*proto.SyncMappingsRequest{ + {Msg: &proto.SyncMappingsRequest_Init{Init: &proto.SyncMappingsInit{ProxyId: "bad"}}}, + }, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected ack") +} + +func TestSendSnapshotSync_BackPressureOrdering(t *testing.T) { + // Verify batches are sent strictly sequentially โ€” batch N+1 is not sent + // until the ack for batch N is received, including the final batch. + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 6 // 3 batches, 3 acks + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + + var mu sync.Mutex + var events []string + + // Build a stream that logs send/recv events so we can verify ordering. + ackCh := make(chan struct{}, 3) + stream := &orderTrackingStream{ + mu: &mu, + events: &events, + ackCh: ackCh, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + // Feed acks asynchronously after a short delay to simulate real proxy. + go func() { + for range 3 { + time.Sleep(10 * time.Millisecond) + ackCh <- struct{}{} + } + }() + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + + mu.Lock() + defer mu.Unlock() + + // Expected: send, recv-ack, send, recv-ack, send, recv-ack. + require.Len(t, events, 6) + assert.Equal(t, "send", events[0]) + assert.Equal(t, "recv", events[1]) + assert.Equal(t, "send", events[2]) + assert.Equal(t, "recv", events[3]) + assert.Equal(t, "send", events[4]) + assert.Equal(t, "recv", events[5]) +} + +// orderTrackingStream logs "send" and "recv" events and blocks Recv until +// an ack is signaled via ackCh. +type orderTrackingStream struct { + grpc.ServerStream + mu *sync.Mutex + events *[]string + ackCh chan struct{} +} + +func (s *orderTrackingStream) Send(_ *proto.SyncMappingsResponse) error { + s.mu.Lock() + *s.events = append(*s.events, "send") + s.mu.Unlock() + return nil +} + +func (s *orderTrackingStream) Recv() (*proto.SyncMappingsRequest, error) { + <-s.ackCh + s.mu.Lock() + *s.events = append(*s.events, "recv") + s.mu.Unlock() + return ackMsg(), nil +} + +func (s *orderTrackingStream) Context() context.Context { return context.Background() } +func (s *orderTrackingStream) SetHeader(metadata.MD) error { return nil } +func (s *orderTrackingStream) SendHeader(metadata.MD) error { return nil } +func (s *orderTrackingStream) SetTrailer(metadata.MD) {} +func (s *orderTrackingStream) SendMsg(any) error { return nil } +func (s *orderTrackingStream) RecvMsg(any) error { return nil } + +func TestSendSnapshotSync_TokensGeneratedPerBatch(t *testing.T) { + const cluster = "cluster.example.com" + const batchSize = 2 + const totalServices = 4 + const ttl = 100 * time.Millisecond + const ackDelay = 200 * time.Millisecond + + ctrl := gomock.NewController(t) + mgr := rpservice.NewMockManager(ctrl) + mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil) + + s := newSnapshotTestServer(t, batchSize) + s.serviceManager = mgr + s.tokenTTL = ttl + + // Build a stream that validates tokens immediately on Send, then + // delays the ack to ensure the next batch's tokens are generated fresh. + var validateErrs []error + ackCh := make(chan struct{}, 2) + stream := &tokenValidatingSyncStream{ + tokenStore: s.tokenStore, + validateErrs: &validateErrs, + ackCh: ackCh, + } + conn := &proxyConnection{ + proxyID: "proxy-a", + address: cluster, + syncStream: stream, + } + + go func() { + // Delay first ack so that if tokens were all generated upfront they'd expire. + time.Sleep(ackDelay) + ackCh <- struct{}{} + // Final batch ack โ€” immediate. + ackCh <- struct{}{} + }() + + err := s.sendSnapshotSync(context.Background(), conn, stream) + require.NoError(t, err) + require.Empty(t, validateErrs, + "tokens must remain valid: per-batch generation guarantees freshness") +} + +type tokenValidatingSyncStream struct { + grpc.ServerStream + tokenStore *OneTimeTokenStore + validateErrs *[]error + ackCh chan struct{} +} + +func (s *tokenValidatingSyncStream) Send(m *proto.SyncMappingsResponse) error { + for _, mapping := range m.Mapping { + if err := s.tokenStore.ValidateAndConsume(mapping.AuthToken, mapping.AccountId, mapping.Id); err != nil { + *s.validateErrs = append(*s.validateErrs, fmt.Errorf("svc %s: %w", mapping.Id, err)) + } + } + return nil +} + +func (s *tokenValidatingSyncStream) Recv() (*proto.SyncMappingsRequest, error) { + <-s.ackCh + return ackMsg(), nil +} + +func (s *tokenValidatingSyncStream) Context() context.Context { return context.Background() } +func (s *tokenValidatingSyncStream) SetHeader(metadata.MD) error { return nil } +func (s *tokenValidatingSyncStream) SendHeader(metadata.MD) error { return nil } +func (s *tokenValidatingSyncStream) SetTrailer(metadata.MD) {} +func (s *tokenValidatingSyncStream) SendMsg(any) error { return nil } +func (s *tokenValidatingSyncStream) RecvMsg(any) error { return nil } + +func TestConnectionSendResponse_RoutesToSyncStream(t *testing.T) { + stream := &syncRecordingStream{} + conn := &proxyConnection{ + syncStream: stream, + } + + resp := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{ + {Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}, + }, + InitialSyncComplete: true, + } + + err := conn.sendResponse(resp) + require.NoError(t, err) + + require.Len(t, stream.sent, 1) + assert.Len(t, stream.sent[0].Mapping, 1) + assert.Equal(t, "svc-1", stream.sent[0].Mapping[0].Id) + assert.True(t, stream.sent[0].InitialSyncComplete) +} + +func TestConnectionSendResponse_RoutesToLegacyStream(t *testing.T) { + stream := &recordingStream{} + conn := &proxyConnection{ + stream: stream, + } + + resp := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{ + {Id: "svc-2", AccountId: "acct-2"}, + }, + } + + err := conn.sendResponse(resp) + require.NoError(t, err) + + require.Len(t, stream.messages, 1) + assert.Equal(t, "svc-2", stream.messages[0].Mapping[0].Id) +} diff --git a/proxy/handle_mapping_stream_test.go b/proxy/handle_mapping_stream_test.go index cb16c0814..67c399e44 100644 --- a/proxy/handle_mapping_stream_test.go +++ b/proxy/handle_mapping_stream_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "testing" + "time" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -59,7 +60,7 @@ func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) { } syncDone := false - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) assert.NoError(t, err) assert.True(t, syncDone, "initial sync should be marked done when flag is set") } @@ -79,7 +80,7 @@ func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) { } syncDone := false - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) assert.NoError(t, err) assert.False(t, syncDone, "initial sync should not be marked done without flag") } @@ -97,7 +98,7 @@ func TestHandleMappingStream_NilHealthChecker(t *testing.T) { } syncDone := false - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) assert.NoError(t, err) assert.True(t, syncDone, "sync done flag should be set even without health checker") } diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 573485625..41a6b0dd4 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -25,6 +25,11 @@ type Metrics struct { backendDuration metric.Int64Histogram certificateIssueDuration metric.Int64Histogram + // Management sync metrics. + snapshotSyncDuration metric.Int64Histogram + snapshotBatchDuration metric.Int64Histogram + addPeerDuration metric.Int64Histogram + // L4 service-level metrics. l4Services metric.Int64UpDownCounter @@ -54,6 +59,9 @@ func New(ctx context.Context, meter metric.Meter) (*Metrics, error) { if err := m.initHTTPMetrics(meter); err != nil { return nil, err } + if err := m.initSyncMetrics(meter); err != nil { + return nil, err + } if err := m.initL4Metrics(meter); err != nil { return nil, err } @@ -126,6 +134,59 @@ func (m *Metrics) initHTTPMetrics(meter metric.Meter) error { return err } +func (m *Metrics) initSyncMetrics(meter metric.Meter) error { + var err error + + m.snapshotSyncDuration, err = meter.Int64Histogram( + "proxy.sync.snapshot.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration from management connect until the initial snapshot sync is complete"), + metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000), + ) + if err != nil { + return err + } + + m.snapshotBatchDuration, err = meter.Int64Histogram( + "proxy.sync.batch.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration to process a single mapping batch during initial snapshot sync"), + metric.WithExplicitBucketBoundaries(100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000, 300000), + ) + if err != nil { + return err + } + + m.addPeerDuration, err = meter.Int64Histogram( + "proxy.peer.add.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithDescription("Duration to add a peer for an account (keygen + gRPC CreateProxyPeer + embed.New)"), + metric.WithExplicitBucketBoundaries(10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000), + ) + return err +} + +// RecordSnapshotSyncDuration records the total time from connect to sync-complete. +func (m *Metrics) RecordSnapshotSyncDuration(d time.Duration) { + m.snapshotSyncDuration.Record(m.ctx, d.Milliseconds()) +} + +// RecordSnapshotBatchDuration records the time to process one mapping batch during initial sync. +func (m *Metrics) RecordSnapshotBatchDuration(d time.Duration) { + m.snapshotBatchDuration.Record(m.ctx, d.Milliseconds()) +} + +// RecordAddPeerDuration records the time to create a new peer for an account. +func (m *Metrics) RecordAddPeerDuration(d time.Duration, err error) { + result := "success" + if err != nil { + result = "error" + } + m.addPeerDuration.Record(m.ctx, d.Milliseconds(), metric.WithAttributes( + attribute.String("result", result), + )) +} + func (m *Metrics) initL4Metrics(meter metric.Meter) error { var err error diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index e38e3dc4e..28eba7810 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -76,6 +76,11 @@ type clientEntry struct { services map[ServiceKey]serviceInfo createdAt time.Time started bool + // ready is closed once the client has been fully initialized. + // Callers that find a pending entry wait on this channel before + // accessing the client. A nil initErr means success. + ready chan struct{} + initErr error // Per-backend in-flight limiting keyed by target host:port. // TODO: clean up stale entries when backend targets change. inflightMu sync.Mutex @@ -137,6 +142,11 @@ type NetBird struct { clients map[types.AccountID]*clientEntry initLogOnce sync.Once statusNotifier statusNotifier + + // OnAddPeer, when set, is called after AddPeer completes for a new account + // (i.e. when a new client was actually created, not when an existing one + // was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New. + OnAddPeer func(d time.Duration, err error) } // ClientDebugInfo contains debug information about a client. @@ -157,6 +167,9 @@ type skipTLSVerifyContextKey struct{} // AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. // Multiple services can share the same client. +// +// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux +// so that concurrent AddPeer calls for different accounts execute in parallel. func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { si := serviceInfo{serviceID: serviceID} @@ -164,10 +177,23 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se entry, exists := n.clients[accountID] if exists { + ready := entry.ready entry.services[key] = si started := entry.started n.clientsMux.Unlock() + // If the entry is still being initialized by another goroutine, wait. + if ready != nil { + select { + case <-ready: + case <-ctx.Done(): + return ctx.Err() + } + if entry.initErr != nil { + return fmt.Errorf("peer initialization failed: %w", entry.initErr) + } + } + n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, @@ -184,15 +210,43 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se return nil } - entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) + // Insert a placeholder so other goroutines calling AddPeer for the same + // account will wait on the ready channel instead of starting a second + // client creation. + entry = &clientEntry{ + services: map[ServiceKey]serviceInfo{key: si}, + ready: make(chan struct{}), + } + n.clients[accountID] = entry + n.clientsMux.Unlock() + + createStart := time.Now() + created, err := n.createClientEntry(ctx, accountID, key, authToken, si) + if n.OnAddPeer != nil { + n.OnAddPeer(time.Since(createStart), err) + } if err != nil { + entry.initErr = err + close(entry.ready) + + n.clientsMux.Lock() + delete(n.clients, accountID) n.clientsMux.Unlock() return err } - n.clients[accountID] = entry + // Transfer any services that were registered by concurrent AddPeer calls + // while we were creating the client. + n.clientsMux.Lock() + for k, v := range entry.services { + created.services[k] = v + } + created.ready = nil + n.clients[accountID] = created n.clientsMux.Unlock() + close(entry.ready) + n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, @@ -200,13 +254,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se // Attempt to start the client in the background; if this fails we will // retry on the first request via RoundTrip. - go n.runClientStartup(ctx, accountID, entry.client) + go n.runClientStartup(ctx, accountID, created.client) return nil } // createClientEntry generates a WireGuard keypair, authenticates with management, -// and creates an embedded NetBird client. Must be called with clientsMux held. +// and creates an embedded NetBird client. func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { serviceID := si.serviceID n.logger.WithFields(log.Fields{ diff --git a/proxy/process_mappings_bench_test.go b/proxy/process_mappings_bench_test.go new file mode 100644 index 000000000..ca0792590 --- /dev/null +++ b/proxy/process_mappings_bench_test.go @@ -0,0 +1,300 @@ +package proxy + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/auth" + "github.com/netbirdio/netbird/proxy/internal/conntrack" + "github.com/netbirdio/netbird/proxy/internal/crowdsec" + proxymetrics "github.com/netbirdio/netbird/proxy/internal/metrics" + "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" + "github.com/netbirdio/netbird/proxy/internal/types" + udprelay "github.com/netbirdio/netbird/proxy/internal/udp" + "github.com/netbirdio/netbird/shared/management/proto" + + "go.opentelemetry.io/otel/metric/noop" +) + +// latencyMockClient simulates realistic gRPC latency for management calls. +type latencyMockClient struct { + proto.ProxyServiceClient + createPeerDelay time.Duration + statusUpdateDelay time.Duration +} + +func (m *latencyMockClient) SendStatusUpdate(ctx context.Context, _ *proto.SendStatusUpdateRequest, _ ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) { + if m.statusUpdateDelay > 0 { + select { + case <-time.After(m.statusUpdateDelay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &proto.SendStatusUpdateResponse{}, nil +} + +func (m *latencyMockClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) { + if m.createPeerDelay > 0 { + select { + case <-time.After(m.createPeerDelay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &proto.CreateProxyPeerResponse{Success: true}, nil +} + +type discardWriter struct{} + +func (discardWriter) Write(p []byte) (int, error) { return len(p), nil } + +func benchServerWithLatency(b *testing.B, createPeerDelay, statusDelay time.Duration) *Server { + b.Helper() + logger := log.New() + logger.SetLevel(log.FatalLevel) + logger.SetOutput(&discardWriter{}) + + meter, err := proxymetrics.New(context.Background(), noop.Meter{}) + if err != nil { + b.Fatal(err) + } + + mgmtClient := &latencyMockClient{ + createPeerDelay: createPeerDelay, + statusUpdateDelay: statusDelay, + } + + nb := roundtrip.NewNetBird("bench-proxy", "bench.test", + roundtrip.ClientConfig{MgmtAddr: "http://bench.test:9999"}, + logger, nil, mgmtClient) + + mainRouter := nbtcp.NewRouter(logger, func(accountID types.AccountID) (types.DialContextFunc, error) { + return (&net.Dialer{}).DialContext, nil + }, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}) + + return &Server{ + Logger: logger, + mgmtClient: mgmtClient, + netbird: nb, + proxy: proxy.NewReverseProxy(nil, "auto", nil, logger), + auth: auth.NewMiddleware(logger, nil, nil), + mainRouter: mainRouter, + mainPort: 443, + meter: meter, + hijackTracker: conntrack.HijackTracker{}, + crowdsecRegistry: crowdsec.NewRegistry("", "", log.NewEntry(logger)), + crowdsecServices: make(map[types.ServiceID]bool), + lastMappings: make(map[types.ServiceID]*proto.ProxyMapping), + portRouters: make(map[uint16]*portRouter), + svcPorts: make(map[types.ServiceID][]uint16), + udpRelays: make(map[types.ServiceID]*udprelay.Relay), + } +} + +// generateHTTPMappings creates N HTTP-mode mappings with the given update type. +// All belong to a single account to share the embedded client. +func generateHTTPMappings(n int, updateType proto.ProxyMappingUpdateType) []*proto.ProxyMapping { + mappings := make([]*proto.ProxyMapping, n) + for i := range n { + mappings[i] = &proto.ProxyMapping{ + Type: updateType, + Id: fmt.Sprintf("svc-%d", i), + AccountId: "account-1", + Domain: fmt.Sprintf("svc-%d.bench.example.com", i), + Mode: "http", + Path: []*proto.PathMapping{ + { + Path: "/", + Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256), + }, + }, + Auth: &proto.Authentication{}, + } + } + return mappings +} + +// generateMultiAccountHTTPMappings creates N HTTP-mode CREATED mappings spread +// across the given number of accounts. This stresses the AddPeer new-account +// path which calls CreateProxyPeer + embed.New per unique account. +func generateMultiAccountHTTPMappings(n, accounts int) []*proto.ProxyMapping { + mappings := make([]*proto.ProxyMapping, n) + for i := range n { + mappings[i] = &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: fmt.Sprintf("svc-%d", i), + AccountId: fmt.Sprintf("account-%d", i%accounts), + Domain: fmt.Sprintf("svc-%d.bench.example.com", i), + Mode: "http", + Path: []*proto.PathMapping{ + { + Path: "/", + Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256), + }, + }, + Auth: &proto.Authentication{}, + } + } + return mappings +} + +// generateMixedMappings creates mappings with a realistic distribution: +// 70% HTTP create, 15% modify existing, 10% TLS on main port, 5% remove. +// All use a single account to avoid embed.New dialing. +func generateMixedMappings(n int) []*proto.ProxyMapping { + mappings := make([]*proto.ProxyMapping, n) + for i := range n { + var m *proto.ProxyMapping + switch { + case i%20 < 14: // 70% HTTP create + m = &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: fmt.Sprintf("svc-http-%d", i), + AccountId: "account-1", + Domain: fmt.Sprintf("svc-%d.bench.example.com", i), + Mode: "http", + Path: []*proto.PathMapping{ + {Path: "/", Target: fmt.Sprintf("http://10.0.%d.%d:8080", (i/256)%256, i%256)}, + {Path: "/api", Target: fmt.Sprintf("http://10.0.%d.%d:8081", (i/256)%256, i%256)}, + }, + Auth: &proto.Authentication{}, + } + case i%20 < 17: // 15% modify + m = &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED, + Id: fmt.Sprintf("svc-http-%d", i%100), + AccountId: "account-1", + Domain: fmt.Sprintf("svc-%d.bench.example.com", i%100), + Mode: "http", + Path: []*proto.PathMapping{ + {Path: "/", Target: fmt.Sprintf("http://10.1.%d.%d:8080", (i/256)%256, i%256)}, + }, + Auth: &proto.Authentication{}, + } + case i%20 < 19: // 10% TLS passthrough on main port + m = &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: fmt.Sprintf("svc-tls-%d", i), + AccountId: "account-1", + Domain: fmt.Sprintf("tls-%d.bench.example.com", i), + Mode: "tls", + ListenPort: 443, + Path: []*proto.PathMapping{ + {Path: "/", Target: fmt.Sprintf("10.2.%d.%d:443", (i/256)%256, i%256)}, + }, + } + default: // 5% remove + m = &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: fmt.Sprintf("svc-http-%d", i%50), + AccountId: "account-1", + Domain: fmt.Sprintf("svc-%d.bench.example.com", i%50), + Mode: "http", + } + } + mappings[i] = m + } + return mappings +} + +const ( + createPeerLatency = 100 * time.Millisecond + statusUpdateLatency = 50 * time.Millisecond +) + +// BenchmarkProcessMappings_HTTPCreate_SingleAccount benchmarks the initial sync +// scenario: N HTTP mappings all on a single account. Only the first mapping +// triggers CreateProxyPeer (100ms gRPC). The rest just register with the +// existing client. This is the "best case" production path. +func BenchmarkProcessMappings_HTTPCreate_SingleAccount(b *testing.B) { + for _, n := range []int{100, 1000, 5000} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + mappings := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED) + for range b.N { + s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency) + s.processMappings(b.Context(), mappings) + } + }) + } +} + +// BenchmarkProcessMappings_HTTPCreate_MultiAccount benchmarks the worst-case +// initial sync: every mapping belongs to a different account, so each one +// triggers a full CreateProxyPeer gRPC round-trip (100ms) + embed.New. +// With 500 accounts this serializes to ~50s of blocking I/O. +func BenchmarkProcessMappings_HTTPCreate_MultiAccount(b *testing.B) { + for _, tc := range []struct { + mappings int + accounts int + }{ + {100, 10}, + {100, 50}, + {1000, 50}, + {1000, 200}, + {3000, 500}, + } { + b.Run(fmt.Sprintf("mappings=%d/accounts=%d", tc.mappings, tc.accounts), func(b *testing.B) { + mappings := generateMultiAccountHTTPMappings(tc.mappings, tc.accounts) + for range b.N { + s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency) + s.processMappings(b.Context(), mappings) + } + }) + } +} + +// BenchmarkProcessMappings_Mixed benchmarks a realistic mixed workload +// of creates, modifies, TLS, and removes with production-like latency. +// TLS mappings call SendStatusUpdate (50ms each), serialized. +func BenchmarkProcessMappings_Mixed(b *testing.B) { + for _, n := range []int{100, 1000, 5000} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + mappings := generateMixedMappings(n) + for range b.N { + s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency) + creates := generateHTTPMappings(100, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED) + s.processMappings(b.Context(), creates) + s.processMappings(b.Context(), mappings) + } + }) + } +} + +// BenchmarkProcessMappings_ModifyOnly benchmarks bulk modification of +// already-registered mappings (no new peers needed, no gRPC). +func BenchmarkProcessMappings_ModifyOnly(b *testing.B) { + for _, n := range []int{100, 1000, 5000} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + creates := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED) + modifies := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) + for range b.N { + s := benchServerWithLatency(b, createPeerLatency, statusUpdateLatency) + s.processMappings(b.Context(), creates) + s.processMappings(b.Context(), modifies) + } + }) + } +} + +// BenchmarkProcessMappings_NoLatency measures pure CPU/allocation overhead +// with zero I/O latency for profiling purposes. +func BenchmarkProcessMappings_NoLatency(b *testing.B) { + for _, n := range []int{1000, 5000} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + mappings := generateHTTPMappings(n, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED) + for range b.N { + s := benchServerWithLatency(b, 0, 0) + s.processMappings(b.Context(), mappings) + } + }) + } +} diff --git a/proxy/server.go b/proxy/server.go index 6980e1df1..6ccfa3e9a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -32,9 +32,11 @@ import ( "go.opentelemetry.io/otel/sdk/metric" "golang.org/x/exp/maps" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + grpcstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/internal/accesslog" @@ -282,6 +284,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { WGPort: s.WireguardPort, PreSharedKey: s.PreSharedKey, }, s.Logger, s, s.mgmtClient) + s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration // Create health checker before the mapping worker so it can track // management connectivity from the first stream connection. @@ -938,6 +941,9 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr Clock: backoff.SystemClock, } + // syncSupported tracks whether management supports SyncMappings. + // Starts true; set to false on first Unimplemented error. + syncSupported := true initialSyncDone := false operation := func() error { @@ -949,36 +955,25 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr s.healthChecker.SetManagementConnected(false) } - supportsCrowdSec := s.crowdsecRegistry.Available() - mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ - ProxyId: s.ID, - Version: s.Version, - StartedAt: timestamppb.New(s.startTime), - Address: s.ProxyURL, - Capabilities: &proto.ProxyCapabilities{ - SupportsCustomPorts: &s.SupportsCustomPorts, - RequireSubdomain: &s.RequireSubdomain, - SupportsCrowdsec: &supportsCrowdSec, - }, - }) - if err != nil { - return fmt.Errorf("create mapping stream: %w", err) + var streamErr error + if syncSupported { + streamErr = s.trySyncMappings(ctx, client, &initialSyncDone) + if isSyncUnimplemented(streamErr) { + syncSupported = false + s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate") + streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone) + } + } else { + streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone) } - if s.healthChecker != nil { - s.healthChecker.SetManagementConnected(true) - } - s.Logger.Debug("management mapping stream established") - - // Stream established โ€” reset backoff so the next failure retries quickly. - bo.Reset() - - streamErr := s.handleMappingStream(ctx, mappingClient, &initialSyncDone) - if s.healthChecker != nil { s.healthChecker.SetManagementConnected(false) } + // Stream established โ€” reset backoff so the next failure retries quickly. + bo.Reset() + if streamErr == nil { return fmt.Errorf("stream closed by server") } @@ -995,56 +990,187 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } } -func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { +func (s *Server) proxyCapabilities() *proto.ProxyCapabilities { + supportsCrowdSec := s.crowdsecRegistry.Available() + return &proto.ProxyCapabilities{ + SupportsCustomPorts: &s.SupportsCustomPorts, + RequireSubdomain: &s.RequireSubdomain, + SupportsCrowdsec: &supportsCrowdSec, + } +} + +func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error { + connectTime := time.Now() + mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: s.ID, + Version: s.Version, + StartedAt: timestamppb.New(s.startTime), + Address: s.ProxyURL, + Capabilities: s.proxyCapabilities(), + }) + if err != nil { + return fmt.Errorf("create mapping stream: %w", err) + } + + if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(true) + } + s.Logger.Debug("management mapping stream established (GetMappingUpdate)") + + return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime) +} + +func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error { + connectTime := time.Now() + stream, err := client.SyncMappings(ctx) + if err != nil { + return fmt.Errorf("create sync stream: %w", err) + } + + // Send init message. + if err := stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: s.ID, + Version: s.Version, + StartedAt: timestamppb.New(s.startTime), + Address: s.ProxyURL, + Capabilities: s.proxyCapabilities(), + }, + }, + }); err != nil { + return fmt.Errorf("send sync init: %w", err) + } + + if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(true) + } + s.Logger.Debug("management mapping stream established (SyncMappings)") + + return s.handleSyncMappingsStream(ctx, stream, initialSyncDone, connectTime) +} + +func isSyncUnimplemented(err error) bool { + if err == nil { + return false + } + st, ok := grpcstatus.FromError(err) + return ok && st.Code() == codes.Unimplemented +} + +func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error { select { case <-s.routerReady: case <-ctx.Done(): return ctx.Err() } - var snapshotIDs map[types.ServiceID]struct{} - if !*initialSyncDone { - snapshotIDs = make(map[types.ServiceID]struct{}) - } + tracker := s.newSnapshotTracker(initialSyncDone, connectTime) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + msg, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return fmt.Errorf("receive msg: %w", err) + } + + batchStart := time.Now() + s.Logger.Debug("Received mapping update, starting processing") + s.processMappings(ctx, msg.GetMapping()) + s.Logger.Debug("Processing mapping update completed") + tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart) + + if err := stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }); err != nil { + return fmt.Errorf("send ack: %w", err) + } + } + } +} + +func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + + tracker := s.newSnapshotTracker(initialSyncDone, connectTime) for { - // Check for context completion to gracefully shutdown. select { case <-ctx.Done(): - // Shutting down. return ctx.Err() default: msg, err := mappingClient.Recv() switch { case errors.Is(err, io.EOF): - // Mapping connection gracefully terminated by server. return nil case err != nil: - // Something has gone horribly wrong, return and hope the parent retries the connection. return fmt.Errorf("receive msg: %w", err) } + + batchStart := time.Now() s.Logger.Debug("Received mapping update, starting processing") s.processMappings(ctx, msg.GetMapping()) s.Logger.Debug("Processing mapping update completed") - - if !*initialSyncDone { - for _, m := range msg.GetMapping() { - snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} - } - if msg.GetInitialSyncComplete() { - s.reconcileSnapshot(ctx, snapshotIDs) - snapshotIDs = nil - if s.healthChecker != nil { - s.healthChecker.SetInitialSyncComplete() - } - *initialSyncDone = true - s.Logger.Info("Initial mapping sync complete") - } - } + tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart) } } } +// snapshotTracker accumulates service IDs during the initial snapshot and +// finalises sync state when the complete flag arrives. +type snapshotTracker struct { + done *bool + connectTime time.Time + snapshotIDs map[types.ServiceID]struct{} +} + +func (s *Server) newSnapshotTracker(done *bool, connectTime time.Time) *snapshotTracker { + var ids map[types.ServiceID]struct{} + if !*done { + ids = make(map[types.ServiceID]struct{}) + } + return &snapshotTracker{done: done, connectTime: connectTime, snapshotIDs: ids} +} + +func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings []*proto.ProxyMapping, syncComplete bool, batchStart time.Time) { + if *t.done { + return + } + + if s.meter != nil { + s.meter.RecordSnapshotBatchDuration(time.Since(batchStart)) + } + + for _, m := range mappings { + t.snapshotIDs[types.ServiceID(m.GetId())] = struct{}{} + } + + if !syncComplete { + return + } + + s.reconcileSnapshot(ctx, t.snapshotIDs) + t.snapshotIDs = nil + if s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + } + *t.done = true + if s.meter != nil { + s.meter.RecordSnapshotSyncDuration(time.Since(t.connectTime)) + } + s.Logger.Info("Initial mapping sync complete") +} + // reconcileSnapshot removes local mappings that are absent from the snapshot. // This ensures services deleted while the proxy was disconnected get cleaned up. func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { @@ -1067,6 +1193,8 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se } func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { + s.ensurePeers(ctx, mappings) + for _, mapping := range mappings { s.Logger.WithFields(log.Fields{ "type": mapping.GetType(), @@ -1100,6 +1228,60 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } +// ensurePeers pre-creates NetBird peers for all unique accounts referenced by +// CREATED mappings. Peers for different accounts are created concurrently, +// which avoids serializing Nร—100ms gRPC round-trips during large initial syncs. +func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) { + // Collect one representative mapping per account that needs a new peer. + type peerReq struct { + accountID types.AccountID + svcKey roundtrip.ServiceKey + authToken string + svcID types.ServiceID + } + seen := make(map[types.AccountID]struct{}) + var reqs []peerReq + for _, m := range mappings { + if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED { + continue + } + accountID := types.AccountID(m.GetAccountId()) + if _, ok := seen[accountID]; ok { + continue + } + seen[accountID] = struct{}{} + if s.netbird.HasClient(accountID) { + continue + } + reqs = append(reqs, peerReq{ + accountID: accountID, + svcKey: s.serviceKeyForMapping(m), + authToken: m.GetAuthToken(), + svcID: types.ServiceID(m.GetId()), + }) + } + + if len(reqs) <= 1 { + return + } + + var wg sync.WaitGroup + wg.Add(len(reqs)) + for _, r := range reqs { + go func() { + defer wg.Done() + if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil { + s.Logger.WithFields(log.Fields{ + "account_id": r.accountID, + "service_id": r.svcID, + "error": err, + }).Warn("failed to pre-create peer for account") + } + }() + } + wg.Wait() +} + // addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { accountID := types.AccountID(mapping.GetAccountId()) diff --git a/proxy/snapshot_reconcile_test.go b/proxy/snapshot_reconcile_test.go index 042d8df77..2e6c80cfc 100644 --- a/proxy/snapshot_reconcile_test.go +++ b/proxy/snapshot_reconcile_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "testing" + "time" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -139,7 +140,7 @@ func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) { } syncDone := false - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) assert.NoError(t, err) assert.True(t, syncDone, "sync should be marked done after final batch") } @@ -164,7 +165,7 @@ func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) { } syncDone := true // sync already completed in a previous stream - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) require.NoError(t, err) assert.Len(t, s.lastMappings, 2, @@ -185,7 +186,7 @@ func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) { stream := &mockMappingStream{} // no messages โ†’ immediate EOF syncDone := false - err := s.handleMappingStream(context.Background(), stream, &syncDone) + err := s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{}) assert.NoError(t, err) assert.False(t, syncDone, "sync should not be marked done on immediate EOF") @@ -218,7 +219,7 @@ func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) { s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"} syncDone := false - err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone) + err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone, time.Time{}) assert.Error(t, err) assert.False(t, syncDone) diff --git a/proxy/sync_mappings_test.go b/proxy/sync_mappings_test.go new file mode 100644 index 000000000..801587e4c --- /dev/null +++ b/proxy/sync_mappings_test.go @@ -0,0 +1,525 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + grpcstatus "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestIntegration_SyncMappings_HappyPath(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + // Send init. + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-1", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + mappingsByID := make(map[string]*proto.ProxyMapping) + for { + msg, err := stream.Recv() + require.NoError(t, err) + for _, m := range msg.GetMapping() { + mappingsByID[m.GetId()] = m + } + + // Ack every batch. + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + assert.Len(t, mappingsByID, 2, "Should receive 2 mappings") + + rp1 := mappingsByID["rp-1"] + require.NotNil(t, rp1) + assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain()) + assert.Equal(t, "test-account-1", rp1.GetAccountId()) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType()) + assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token") + + rp2 := mappingsByID["rp-2"] + require.NotNil(t, rp2) + assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain()) +} + +func TestIntegration_SyncMappings_BackPressure(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + // Add enough services to guarantee multiple batches (default batch size 500). + addServicesToStore(t, setup, 600, "test.proxy.io") + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-backpressure", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + // Strategy: receive batch 1, then hold for a significant delay before + // acking. If back-pressure works, batch 2 cannot arrive until after + // the ack is sent โ€” so its receive timestamp must be >= the ack + // timestamp. If management were fire-and-forget, all batches would + // already be buffered in the gRPC transport and batch 2 would arrive + // well before the ack time. + const ackDelay = 300 * time.Millisecond + + type batchEvent struct { + recvAt time.Time + ackAt time.Time + count int + } + var batches []batchEvent + var totalMappings int + + for { + msg, err := stream.Recv() + require.NoError(t, err) + + recvAt := time.Now() + totalMappings += len(msg.GetMapping()) + + // Delay the ack on non-final batches to create a measurable gap. + if !msg.GetInitialSyncComplete() { + time.Sleep(ackDelay) + } + + ackAt := time.Now() + batches = append(batches, batchEvent{ + recvAt: recvAt, + ackAt: ackAt, + count: len(msg.GetMapping()), + }) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // 2 original + 600 added = 602 services total. + assert.Equal(t, 602, totalMappings, "should receive all 602 mappings") + require.GreaterOrEqual(t, len(batches), 2, "need at least 2 batches to verify back-pressure") + + // For every batch after the first, its receive time must be after the + // previous batch's ack time. This proves management waited for the ack + // before sending the next batch. + for i := 1; i < len(batches); i++ { + prevAckAt := batches[i-1].ackAt + thisRecvAt := batches[i].recvAt + assert.True(t, !thisRecvAt.Before(prevAckAt), + "batch %d received at %v, but batch %d was acked at %v โ€” "+ + "management sent the next batch before receiving the ack", + i, thisRecvAt, i-1, prevAckAt) + } +} + +func TestIntegration_SyncMappings_IncrementalUpdate(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "sync-proxy-incremental", + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + // Drain initial snapshot. + for { + msg, err := stream.Recv() + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // Now send an incremental update via the management server. + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + // Receive the incremental update on the sync stream. + msg, err := stream.Recv() + require.NoError(t, err) + require.NotEmpty(t, msg.GetMapping()) + assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId()) + assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, msg.GetMapping()[0].GetType()) +} + +func TestIntegration_SyncMappings_MixedProxyVersions(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Old proxy uses GetMappingUpdate. + legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "legacy-proxy", + Version: "old-v1", + Address: "test.proxy.io", + }) + require.NoError(t, err) + + var legacyMappings []*proto.ProxyMapping + for { + msg, err := legacyStream.Recv() + require.NoError(t, err) + legacyMappings = append(legacyMappings, msg.GetMapping()...) + if msg.GetInitialSyncComplete() { + break + } + } + + // New proxy uses SyncMappings. + syncStream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = syncStream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "new-proxy", + Version: "new-v2", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + var syncMappings []*proto.ProxyMapping + for { + msg, err := syncStream.Recv() + require.NoError(t, err) + syncMappings = append(syncMappings, msg.GetMapping()...) + + err = syncStream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + + // Both should receive the same set of mappings. + assert.Equal(t, len(legacyMappings), len(syncMappings), + "legacy and sync proxies should receive the same number of mappings") + + legacyIDs := make(map[string]bool) + for _, m := range legacyMappings { + legacyIDs[m.GetId()] = true + } + for _, m := range syncMappings { + assert.True(t, legacyIDs[m.GetId()], + "mapping %s should be present in both streams", m.GetId()) + } + + // Both proxies should be connected. + proxies := setup.proxyService.GetConnectedProxies() + assert.Contains(t, proxies, "legacy-proxy") + assert.Contains(t, proxies, "new-proxy") + + // Both should receive incremental updates. + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + // Legacy proxy receives via GetMappingUpdateResponse. + legacyMsg, err := legacyStream.Recv() + require.NoError(t, err) + assert.Equal(t, "rp-1", legacyMsg.GetMapping()[0].GetId()) + + // Sync proxy receives via SyncMappingsResponse. + syncMsg, err := syncStream.Recv() + require.NoError(t, err) + assert.Equal(t, "rp-1", syncMsg.GetMapping()[0].GetId()) +} + +func TestIntegration_SyncMappings_Reconnect(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + proxyID := "sync-proxy-reconnect" + + receiveMappings := func() []*proto.ProxyMapping { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: proxyID, + Version: "test-v1", + Address: "test.proxy.io", + }, + }, + }) + require.NoError(t, err) + + var mappings []*proto.ProxyMapping + for { + msg, err := stream.Recv() + require.NoError(t, err) + mappings = append(mappings, msg.GetMapping()...) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Ack{Ack: &proto.SyncMappingsAck{}}, + }) + require.NoError(t, err) + + if msg.GetInitialSyncComplete() { + break + } + } + return mappings + } + + first := receiveMappings() + time.Sleep(100 * time.Millisecond) + second := receiveMappings() + + assert.Equal(t, len(first), len(second), + "should receive same mappings on reconnect") + + firstIDs := make(map[string]bool) + for _, m := range first { + firstIDs[m.GetId()] = true + } + for _, m := range second { + assert.True(t, firstIDs[m.GetId()], + "mapping %s should be present in both connections", m.GetId()) + } +} + +// --- Fallback tests: old management returns Unimplemented --- + +// unimplementedProxyServer embeds UnimplementedProxyServiceServer so +// SyncMappings returns codes.Unimplemented while GetMappingUpdate works. +type unimplementedSyncServer struct { + proto.UnimplementedProxyServiceServer + getMappingCalls atomic.Int32 +} + +func (s *unimplementedSyncServer) GetMappingUpdate(_ *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { + s.getMappingCalls.Add(1) + return stream.Send(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{Id: "svc-1", AccountId: "acct-1", Domain: "example.com"}}, + InitialSyncComplete: true, + }) +} + +func TestIntegration_FallbackToGetMappingUpdate(t *testing.T) { + // Start a gRPC server that does NOT implement SyncMappings. + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + srv := &unimplementedSyncServer{} + grpcServer := grpc.NewServer() + proto.RegisterProxyServiceServer(grpcServer, srv) + go func() { _ = grpcServer.Serve(lis) }() + defer grpcServer.GracefulStop() + + conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + // Try SyncMappings โ€” should get Unimplemented. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := client.SyncMappings(ctx) + require.NoError(t, err) + + err = stream.Send(&proto.SyncMappingsRequest{ + Msg: &proto.SyncMappingsRequest_Init{ + Init: &proto.SyncMappingsInit{ + ProxyId: "test-proxy", + Address: "test.example.com", + }, + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + require.Error(t, err) + st, ok := grpcstatus.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Unimplemented, st.Code(), + "unimplemented SyncMappings should return Unimplemented code") + + // isSyncUnimplemented should detect this. + assert.True(t, isSyncUnimplemented(err)) + + // The actual fallback: GetMappingUpdate should work. + legacyStream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ + ProxyId: "test-proxy", + Address: "test.example.com", + }) + require.NoError(t, err) + + msg, err := legacyStream.Recv() + require.NoError(t, err) + assert.True(t, msg.GetInitialSyncComplete()) + assert.Len(t, msg.GetMapping(), 1) + assert.Equal(t, int32(1), srv.getMappingCalls.Load()) +} + +func TestIsSyncUnimplemented(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"non-grpc error", errors.New("random"), false}, + {"grpc internal", grpcstatus.Error(codes.Internal, "fail"), false}, + {"grpc unavailable", grpcstatus.Error(codes.Unavailable, "fail"), false}, + {"grpc unimplemented", grpcstatus.Error(codes.Unimplemented, "method not found"), true}, + { + "wrapped unimplemented", + fmt.Errorf("create sync stream: %w", grpcstatus.Error(codes.Unimplemented, "nope")), + // grpc/status.FromError unwraps in recent versions of grpc-go. + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isSyncUnimplemented(tt.err)) + }) + } +} + +// addServicesToStore adds n extra services to the test store for the given cluster. +func addServicesToStore(t *testing.T, setup *integrationTestSetup, n int, cluster string) { + t.Helper() + ctx := context.Background() + for i := 0; i < n; i++ { + svc := &service.Service{ + ID: fmt.Sprintf("extra-svc-%d", i), + AccountID: "test-account-1", + Name: fmt.Sprintf("Extra Service %d", i), + Domain: fmt.Sprintf("extra-%d.test.proxy.io", i), + ProxyCluster: cluster, + Enabled: true, + Targets: []*service.Target{{ + Path: strPtr("/"), + Host: fmt.Sprintf("10.0.1.%d", i%256), + Port: 8080, + Protocol: "http", + TargetId: fmt.Sprintf("peer-extra-%d", i), + TargetType: "peer", + Enabled: true, + }}, + } + require.NoError(t, setup.store.CreateService(ctx, svc)) + } +} diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index 1095b6411..a3a5e4588 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -1970,6 +1970,269 @@ func (x *ValidateSessionResponse) GetDeniedReason() string { return "" } +// SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings +// stream. The first message MUST be an init; all subsequent messages MUST be +// acks. +type SyncMappingsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Msg: + // + // *SyncMappingsRequest_Init + // *SyncMappingsRequest_Ack + Msg isSyncMappingsRequest_Msg `protobuf_oneof:"msg"` +} + +func (x *SyncMappingsRequest) Reset() { + *x = SyncMappingsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsRequest) ProtoMessage() {} + +func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[25] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsRequest.ProtoReflect.Descriptor instead. +func (*SyncMappingsRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{25} +} + +func (m *SyncMappingsRequest) GetMsg() isSyncMappingsRequest_Msg { + if m != nil { + return m.Msg + } + return nil +} + +func (x *SyncMappingsRequest) GetInit() *SyncMappingsInit { + if x, ok := x.GetMsg().(*SyncMappingsRequest_Init); ok { + return x.Init + } + return nil +} + +func (x *SyncMappingsRequest) GetAck() *SyncMappingsAck { + if x, ok := x.GetMsg().(*SyncMappingsRequest_Ack); ok { + return x.Ack + } + return nil +} + +type isSyncMappingsRequest_Msg interface { + isSyncMappingsRequest_Msg() +} + +type SyncMappingsRequest_Init struct { + Init *SyncMappingsInit `protobuf:"bytes,1,opt,name=init,proto3,oneof"` +} + +type SyncMappingsRequest_Ack struct { + Ack *SyncMappingsAck `protobuf:"bytes,2,opt,name=ack,proto3,oneof"` +} + +func (*SyncMappingsRequest_Init) isSyncMappingsRequest_Msg() {} + +func (*SyncMappingsRequest_Ack) isSyncMappingsRequest_Msg() {} + +// SyncMappingsInit is the first message on the stream, carrying the same +// identification fields as GetMappingUpdateRequest. +type SyncMappingsInit struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ProxyId string `protobuf:"bytes,1,opt,name=proxy_id,json=proxyId,proto3" json:"proxy_id,omitempty"` + Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version,omitempty"` + StartedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=started_at,json=startedAt,proto3" json:"started_at,omitempty"` + Address string `protobuf:"bytes,4,opt,name=address,proto3" json:"address,omitempty"` + Capabilities *ProxyCapabilities `protobuf:"bytes,5,opt,name=capabilities,proto3" json:"capabilities,omitempty"` +} + +func (x *SyncMappingsInit) Reset() { + *x = SyncMappingsInit{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsInit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsInit) ProtoMessage() {} + +func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[26] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsInit.ProtoReflect.Descriptor instead. +func (*SyncMappingsInit) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{26} +} + +func (x *SyncMappingsInit) GetProxyId() string { + if x != nil { + return x.ProxyId + } + return "" +} + +func (x *SyncMappingsInit) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + +func (x *SyncMappingsInit) GetStartedAt() *timestamppb.Timestamp { + if x != nil { + return x.StartedAt + } + return nil +} + +func (x *SyncMappingsInit) GetAddress() string { + if x != nil { + return x.Address + } + return "" +} + +func (x *SyncMappingsInit) GetCapabilities() *ProxyCapabilities { + if x != nil { + return x.Capabilities + } + return nil +} + +// SyncMappingsAck is sent by the proxy after it has fully processed a batch. +// Management waits for this before sending the next batch. +type SyncMappingsAck struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SyncMappingsAck) Reset() { + *x = SyncMappingsAck{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsAck) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsAck) ProtoMessage() {} + +func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[27] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsAck.ProtoReflect.Descriptor instead. +func (*SyncMappingsAck) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{27} +} + +// SyncMappingsResponse is a batch of mappings sent by management. +// Identical semantics to GetMappingUpdateResponse. +type SyncMappingsResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Mapping []*ProxyMapping `protobuf:"bytes,1,rep,name=mapping,proto3" json:"mapping,omitempty"` + // initial_sync_complete is set on the last message of the initial snapshot. + InitialSyncComplete bool `protobuf:"varint,2,opt,name=initial_sync_complete,json=initialSyncComplete,proto3" json:"initial_sync_complete,omitempty"` +} + +func (x *SyncMappingsResponse) Reset() { + *x = SyncMappingsResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[28] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SyncMappingsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SyncMappingsResponse) ProtoMessage() {} + +func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[28] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SyncMappingsResponse.ProtoReflect.Descriptor instead. +func (*SyncMappingsResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{28} +} + +func (x *SyncMappingsResponse) GetMapping() []*ProxyMapping { + if x != nil { + return x.Mapping + } + return nil +} + +func (x *SyncMappingsResponse) GetInitialSyncComplete() bool { + if x != nil { + return x.InitialSyncComplete + } + return false +} + var File_proxy_service_proto protoreflect.FileDescriptor var file_proxy_service_proto_rawDesc = []byte{ @@ -2254,37 +2517,74 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, - 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, - 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, - 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, - 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, - 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, - 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, - 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, - 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, - 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, - 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, - 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, - 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, - 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, - 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, - 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, - 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, - 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, - 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, - 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, - 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, - 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, - 0xfc, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, - 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, + 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x32, + 0x0a, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, 0x69, 0x6e, + 0x69, 0x74, 0x12, 0x2f, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, + 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, 0x00, 0x52, 0x03, + 0x61, 0x63, 0x6b, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, 0x0a, 0x10, 0x53, + 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x12, + 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, + 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, + 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, + 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x11, 0x0a, 0x0f, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x22, + 0x7e, 0x0a, 0x14, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, + 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, + 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x2a, + 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, + 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, + 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, + 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, + 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, + 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, + 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, + 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, + 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, + 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, + 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, + 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, + 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, + 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, + 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, + 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, + 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, + 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, + 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, + 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, + 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xd3, 0x05, 0x0a, 0x0c, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, + 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x53, 0x79, + 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, @@ -2334,7 +2634,7 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 27) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 31) var file_proxy_service_proto_goTypes = []interface{}{ (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType (PathRewriteMode)(0), // 1: management.PathRewriteMode @@ -2364,19 +2664,23 @@ var file_proxy_service_proto_goTypes = []interface{}{ (*GetOIDCURLResponse)(nil), // 25: management.GetOIDCURLResponse (*ValidateSessionRequest)(nil), // 26: management.ValidateSessionRequest (*ValidateSessionResponse)(nil), // 27: management.ValidateSessionResponse - nil, // 28: management.PathTargetOptions.CustomHeadersEntry - nil, // 29: management.AccessLog.MetadataEntry - (*timestamppb.Timestamp)(nil), // 30: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 31: google.protobuf.Duration + (*SyncMappingsRequest)(nil), // 28: management.SyncMappingsRequest + (*SyncMappingsInit)(nil), // 29: management.SyncMappingsInit + (*SyncMappingsAck)(nil), // 30: management.SyncMappingsAck + (*SyncMappingsResponse)(nil), // 31: management.SyncMappingsResponse + nil, // 32: management.PathTargetOptions.CustomHeadersEntry + nil, // 33: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 35: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 30, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 34, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 31, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 35, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 28, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 31, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 32, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 35, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType @@ -2384,31 +2688,38 @@ var file_proxy_service_proto_depIdxs = []int32{ 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 30, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 29, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 34, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 33, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 4, // 20: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 12, // 21: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 15, // 22: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 20, // 23: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 22, // 24: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 24, // 25: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 26, // 26: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 5, // 27: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 13, // 28: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 19, // 29: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 21, // 30: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 23, // 31: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 25, // 32: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 27, // 33: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 27, // [27:34] is the sub-list for method output_type - 20, // [20:27] is the sub-list for method input_type - 20, // [20:20] is the sub-list for extension type_name - 20, // [20:20] is the sub-list for extension extendee - 0, // [0:20] is the sub-list for field type_name + 29, // 20: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit + 30, // 21: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck + 34, // 22: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp + 3, // 23: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities + 11, // 24: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping + 4, // 25: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 28, // 26: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest + 12, // 27: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 28: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 29: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 22, // 30: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 24, // 31: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 26, // 32: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 5, // 33: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 31, // 34: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse + 13, // 35: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 36: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 21, // 37: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 23, // 38: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 25, // 39: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 27, // 40: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 33, // [33:41] is the sub-list for method output_type + 25, // [25:33] is the sub-list for method input_type + 25, // [25:25] is the sub-list for extension type_name + 25, // [25:25] is the sub-list for extension extendee + 0, // [0:25] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -2717,6 +3028,54 @@ func file_proxy_service_proto_init() { return nil } } + file_proxy_service_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsInit); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsAck); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_proxy_service_proto_msgTypes[0].OneofWrappers = []interface{}{} file_proxy_service_proto_msgTypes[12].OneofWrappers = []interface{}{ @@ -2726,13 +3085,17 @@ func file_proxy_service_proto_init() { } file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} file_proxy_service_proto_msgTypes[20].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[25].OneofWrappers = []interface{}{ + (*SyncMappingsRequest_Init)(nil), + (*SyncMappingsRequest_Ack)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 27, + NumMessages: 31, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index e359f0cbd..d1171b27e 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -12,6 +12,15 @@ import "google/protobuf/timestamp.proto"; service ProxyService { rpc GetMappingUpdate(GetMappingUpdateRequest) returns (stream GetMappingUpdateResponse); + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + rpc SyncMappings(stream SyncMappingsRequest) returns (stream SyncMappingsResponse); + rpc SendAccessLog(SendAccessLogRequest) returns (SendAccessLogResponse); rpc Authenticate(AuthenticateRequest) returns (AuthenticateResponse); @@ -246,3 +255,35 @@ message ValidateSessionResponse { string user_email = 3; string denied_reason = 4; } + +// SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings +// stream. The first message MUST be an init; all subsequent messages MUST be +// acks. +message SyncMappingsRequest { + oneof msg { + SyncMappingsInit init = 1; + SyncMappingsAck ack = 2; + } +} + +// SyncMappingsInit is the first message on the stream, carrying the same +// identification fields as GetMappingUpdateRequest. +message SyncMappingsInit { + string proxy_id = 1; + string version = 2; + google.protobuf.Timestamp started_at = 3; + string address = 4; + ProxyCapabilities capabilities = 5; +} + +// SyncMappingsAck is sent by the proxy after it has fully processed a batch. +// Management waits for this before sending the next batch. +message SyncMappingsAck {} + +// SyncMappingsResponse is a batch of mappings sent by management. +// Identical semantics to GetMappingUpdateResponse. +message SyncMappingsResponse { + repeated ProxyMapping mapping = 1; + // initial_sync_complete is set on the last message of the initial snapshot. + bool initial_sync_complete = 2; +} diff --git a/shared/management/proto/proxy_service_grpc.pb.go b/shared/management/proto/proxy_service_grpc.pb.go index 627b217d8..fdc031ed7 100644 --- a/shared/management/proto/proxy_service_grpc.pb.go +++ b/shared/management/proto/proxy_service_grpc.pb.go @@ -19,6 +19,14 @@ const _ = grpc.SupportPackageIsVersion7 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type ProxyServiceClient interface { GetMappingUpdate(ctx context.Context, in *GetMappingUpdateRequest, opts ...grpc.CallOption) (ProxyService_GetMappingUpdateClient, error) + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + SyncMappings(ctx context.Context, opts ...grpc.CallOption) (ProxyService_SyncMappingsClient, error) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*AuthenticateResponse, error) SendStatusUpdate(ctx context.Context, in *SendStatusUpdateRequest, opts ...grpc.CallOption) (*SendStatusUpdateResponse, error) @@ -69,6 +77,37 @@ func (x *proxyServiceGetMappingUpdateClient) Recv() (*GetMappingUpdateResponse, return m, nil } +func (c *proxyServiceClient) SyncMappings(ctx context.Context, opts ...grpc.CallOption) (ProxyService_SyncMappingsClient, error) { + stream, err := c.cc.NewStream(ctx, &ProxyService_ServiceDesc.Streams[1], "/management.ProxyService/SyncMappings", opts...) + if err != nil { + return nil, err + } + x := &proxyServiceSyncMappingsClient{stream} + return x, nil +} + +type ProxyService_SyncMappingsClient interface { + Send(*SyncMappingsRequest) error + Recv() (*SyncMappingsResponse, error) + grpc.ClientStream +} + +type proxyServiceSyncMappingsClient struct { + grpc.ClientStream +} + +func (x *proxyServiceSyncMappingsClient) Send(m *SyncMappingsRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *proxyServiceSyncMappingsClient) Recv() (*SyncMappingsResponse, error) { + m := new(SyncMappingsResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func (c *proxyServiceClient) SendAccessLog(ctx context.Context, in *SendAccessLogRequest, opts ...grpc.CallOption) (*SendAccessLogResponse, error) { out := new(SendAccessLogResponse) err := c.cc.Invoke(ctx, "/management.ProxyService/SendAccessLog", in, out, opts...) @@ -128,6 +167,14 @@ func (c *proxyServiceClient) ValidateSession(ctx context.Context, in *ValidateSe // for forward compatibility type ProxyServiceServer interface { GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error + // SyncMappings is a bidirectional stream that replaces GetMappingUpdate for + // new proxies. The proxy sends an initial SyncMappingsRequest to start the + // stream and then sends an ack after each batch is fully processed. + // Management waits for the ack before sending the next batch, providing + // application-level back-pressure during large initial syncs. + // Old proxies continue using GetMappingUpdate; old management servers + // return Unimplemented for this RPC and proxies fall back. + SyncMappings(ProxyService_SyncMappingsServer) error SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) Authenticate(context.Context, *AuthenticateRequest) (*AuthenticateResponse, error) SendStatusUpdate(context.Context, *SendStatusUpdateRequest) (*SendStatusUpdateResponse, error) @@ -146,6 +193,9 @@ type UnimplementedProxyServiceServer struct { func (UnimplementedProxyServiceServer) GetMappingUpdate(*GetMappingUpdateRequest, ProxyService_GetMappingUpdateServer) error { return status.Errorf(codes.Unimplemented, "method GetMappingUpdate not implemented") } +func (UnimplementedProxyServiceServer) SyncMappings(ProxyService_SyncMappingsServer) error { + return status.Errorf(codes.Unimplemented, "method SyncMappings not implemented") +} func (UnimplementedProxyServiceServer) SendAccessLog(context.Context, *SendAccessLogRequest) (*SendAccessLogResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SendAccessLog not implemented") } @@ -198,6 +248,32 @@ func (x *proxyServiceGetMappingUpdateServer) Send(m *GetMappingUpdateResponse) e return x.ServerStream.SendMsg(m) } +func _ProxyService_SyncMappings_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ProxyServiceServer).SyncMappings(&proxyServiceSyncMappingsServer{stream}) +} + +type ProxyService_SyncMappingsServer interface { + Send(*SyncMappingsResponse) error + Recv() (*SyncMappingsRequest, error) + grpc.ServerStream +} + +type proxyServiceSyncMappingsServer struct { + grpc.ServerStream +} + +func (x *proxyServiceSyncMappingsServer) Send(m *SyncMappingsResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *proxyServiceSyncMappingsServer) Recv() (*SyncMappingsRequest, error) { + m := new(SyncMappingsRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func _ProxyService_SendAccessLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(SendAccessLogRequest) if err := dec(in); err != nil { @@ -344,6 +420,12 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{ Handler: _ProxyService_GetMappingUpdate_Handler, ServerStreams: true, }, + { + StreamName: "SyncMappings", + Handler: _ProxyService_SyncMappings_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "proxy_service.proto", } From 454ff66518feaef2d9ceb10fb1960da28f14fd87 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 20 May 2026 18:24:00 +0200 Subject: [PATCH 21/31] [management] scope network router update call (#6222) --- .../networks_handler_integration_test.go | 8 +- management/server/networks/manager_test.go | 7 +- management/server/networks/routers/manager.go | 15 ++- .../server/networks/routers/manager_test.go | 97 +++++++++++++++++ management/server/store/sql_store.go | 24 ++++- management/server/store/sql_store_test.go | 37 ++++++- management/server/store/store.go | 3 +- management/server/store/store_mock.go | 102 ++++++++++-------- management/server/testdata/networks.sql | 4 + 9 files changed, 235 insertions(+), 62 deletions(-) diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go index 54f204a8f..a0a49a9ec 100644 --- a/management/server/http/testing/integration/networks_handler_integration_test.go +++ b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -1319,7 +1319,7 @@ func Test_NetworkRouters_Update(t *testing.T) { }, }, { - name: "Update non-existing router creates it", + name: "Update non-existing router returns not found", networkId: "testNetworkId", routerId: "nonExistingRouterId", requestBody: &api.NetworkRouterRequest{ @@ -1328,11 +1328,7 @@ func Test_NetworkRouters_Update(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.Equal(t, "nonExistingRouterId", router.Id) - }, + expectedStatus: http.StatusNotFound, }, { name: "Update router with both peer and peer_groups", diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go index 6fb19d157..24d5f49b7 100644 --- a/management/server/networks/manager_test.go +++ b/management/server/networks/manager_test.go @@ -34,8 +34,11 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) { networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.NoError(t, err) - require.Len(t, networks, 1) - require.Equal(t, "testNetworkId", networks[0].ID) + ids := make([]string, 0, len(networks)) + for _, n := range networks { + ids = append(ids, n.ID) + } + require.ElementsMatch(t, []string{"testNetworkId", "secondNetworkId"}, ids) } func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index c7c3f2ff4..ed5b0e558 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -102,7 +102,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t router.ID = xid.New().String() - err = transaction.SaveNetworkRouter(ctx, router) + err = transaction.CreateNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to create network router: %w", err) } @@ -162,11 +162,20 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return fmt.Errorf("failed to get network: %w", err) } - if network.ID != router.NetworkID { + existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID) + if err != nil { + return fmt.Errorf("failed to get network router: %w", err) + } + + if existing.AccountID != router.AccountID { + return status.NewNetworkRouterNotFoundError(router.ID) + } + + if existing.NetworkID != router.NetworkID { return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) } - err = transaction.SaveNetworkRouter(ctx, router) + err = transaction.UpdateNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to update network router: %w", err) } diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 6be90baa7..7b6d5f14f 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -195,6 +195,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { if err != nil { require.NoError(t, err) } + router.ID = "testRouterId" s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { @@ -210,6 +211,102 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { require.Equal(t, router.Metric, updatedRouter.Metric) } +func Test_UpdateRouterRejectsCrossAccountID(t *testing.T) { + ctx := context.Background() + userID := "testAdminId" + + // Admin of testAccountId tries to update a router that belongs to otherAccountId + // by passing the other account's router ID through the URL. + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) + if err != nil { + require.NoError(t, err) + } + router.ID = "otherRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManager(s) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.Error(t, err) + require.Nil(t, updatedRouter) + + // The other account's router must be untouched. + stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "otherAccountId", "otherRouterId") + require.NoError(t, err) + require.Equal(t, "otherAccountId", stored.AccountID) + require.Equal(t, "otherNetworkId", stored.NetworkID) + require.Equal(t, "otherPeer", stored.Peer) + require.Equal(t, 1, stored.Metric) +} + +func Test_CreateRouterRejectsCrossAccountID(t *testing.T) { + ctx := context.Background() + userID := "testAdminId" + + // Admin of testAccountId tries to create a router in otherAccountId's network. + // The permission check is on router.AccountID (their own), but the network + // lookup must fail because (testAccountId, otherNetworkId) does not exist. + router, err := types.NewNetworkRouter("testAccountId", "otherNetworkId", "testPeerId", []string{}, false, 1, true) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManager(s) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.Error(t, err) + require.Nil(t, createdRouter) + + // No router should have been created in either account's scope under otherNetworkId. + routersInOther, err := s.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, "otherAccountId", "otherNetworkId") + require.NoError(t, err) + require.Len(t, routersInOther, 1) + require.Equal(t, "otherRouterId", routersInOther[0].ID) +} + +func Test_UpdateRouterRejectsNetworkMismatch(t *testing.T) { + ctx := context.Background() + userID := "testAdminId" + + // The router exists in testNetworkId, but the caller submits secondNetworkId + // (a different network in the same account). The update must be refused. + router, err := types.NewNetworkRouter("testAccountId", "secondNetworkId", "testPeerId", []string{}, false, 1, true) + if err != nil { + require.NoError(t, err) + } + router.ID = "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManager(s) + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.Error(t, err) + require.Nil(t, updatedRouter) + + stored, err := s.GetNetworkRouterByID(ctx, store.LockingStrengthNone, "testAccountId", "testRouterId") + require.NoError(t, err) + require.Equal(t, "testNetworkId", stored.NetworkID) +} + func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { ctx := context.Background() userID := "testUserId" diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index f3c6b741b..279c0e21f 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4315,11 +4315,27 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin return netRouter, nil } -func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error { - result := s.db.Save(router) +func (s *SqlStore) CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error { + if err := s.db.Create(router).Error; err != nil { + log.WithContext(ctx).Errorf("failed to create network router in store: %v", err) + return status.Errorf(status.Internal, "failed to create network router in store") + } + + return nil +} + +func (s *SqlStore) UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error { + result := s.db. + Select("*"). + Where(accountAndIDQueryCondition, router.AccountID, router.ID). + Updates(router) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) - return status.Errorf(status.Internal, "failed to save network router to store") + log.WithContext(ctx).Errorf("failed to update network router in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to update network router in store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkRouterNotFoundError(router.ID) } return nil diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 7515add62..41e3290b6 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2399,7 +2399,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) { } } -func TestSqlStore_SaveNetworkRouter(t *testing.T) { +func TestSqlStore_CreateNetworkRouter(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2410,7 +2410,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) { netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true) require.NoError(t, err) - err = store.SaveNetworkRouter(context.Background(), netRouter) + err = store.CreateNetworkRouter(context.Background(), netRouter) require.NoError(t, err) savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, netRouter.ID) @@ -2418,6 +2418,39 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) { require.Equal(t, netRouter, savedNetRouter) } +func TestSqlStore_UpdateNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + routerID := "ctc20ji7qv9ck2sebc80" + + netRouter := &routerTypes.NetworkRouter{ + ID: routerID, + AccountID: accountID, + NetworkID: networkID, + Peer: "", + PeerGroups: []string{"net-router-grp"}, + Masquerade: true, + Metric: 42, + Enabled: true, + } + + err = store.UpdateNetworkRouter(context.Background(), netRouter) + require.NoError(t, err) + + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, routerID) + require.NoError(t, err) + require.Equal(t, netRouter, savedNetRouter) + + // Updating a router under a different account must not match any row. + netRouter.AccountID = "non-existent-account" + err = store.UpdateNetworkRouter(context.Background(), netRouter) + require.Error(t, err) +} + func TestSqlStore_DeleteNetworkRouter(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) diff --git a/management/server/store/store.go b/management/server/store/store.go index 42cdcf36d..39b1c0ed3 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -228,7 +228,8 @@ type Store interface { GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) - SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error + CreateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error + UpdateNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 4f9d875d2..c7e86c2db 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -310,6 +310,20 @@ func (mr *MockStoreMockRecorder) CreateGroups(ctx, accountID, groups interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroups", reflect.TypeOf((*MockStore)(nil).CreateGroups), ctx, accountID, groups) } +// CreateNetworkRouter mocks base method. +func (m *MockStore) CreateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateNetworkRouter", ctx, router) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateNetworkRouter indicates an expected call of CreateNetworkRouter. +func (mr *MockStoreMockRecorder) CreateNetworkRouter(ctx, router interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNetworkRouter", reflect.TypeOf((*MockStore)(nil).CreateNetworkRouter), ctx, router) +} + // CreatePeerJob mocks base method. func (m *MockStore) CreatePeerJob(ctx context.Context, job *types2.Job) error { m.ctrl.T.Helper() @@ -2612,6 +2626,36 @@ func (mr *MockStoreMockRecorder) MarkPATUsed(ctx, patID interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPATUsed", reflect.TypeOf((*MockStore)(nil).MarkPATUsed), ctx, patID) } +// MarkPeerConnectedIfNewerSession mocks base method. +func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession. +func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt) +} + +// MarkPeerDisconnectedIfSameSession mocks base method. +func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession. +func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt) +} + // MarkPendingJobsAsFailed mocks base method. func (m *MockStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, jobID, reason string) error { m.ctrl.T.Helper() @@ -2822,20 +2866,6 @@ func (mr *MockStoreMockRecorder) SaveNetworkResource(ctx, resource interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkResource", reflect.TypeOf((*MockStore)(nil).SaveNetworkResource), ctx, resource) } -// SaveNetworkRouter mocks base method. -func (m *MockStore) SaveNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveNetworkRouter", ctx, router) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveNetworkRouter indicates an expected call of SaveNetworkRouter. -func (mr *MockStoreMockRecorder) SaveNetworkRouter(ctx, router interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNetworkRouter", reflect.TypeOf((*MockStore)(nil).SaveNetworkRouter), ctx, router) -} - // SavePAT mocks base method. func (m *MockStore) SavePAT(ctx context.Context, pat *types2.PersonalAccessToken) error { m.ctrl.T.Helper() @@ -2892,36 +2922,6 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status) } -// MarkPeerConnectedIfNewerSession mocks base method. -func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession. -func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt) -} - -// MarkPeerDisconnectedIfSameSession mocks base method. -func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession. -func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt) -} - // SavePolicy mocks base method. func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error { m.ctrl.T.Helper() @@ -3173,6 +3173,20 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups) } +// UpdateNetworkRouter mocks base method. +func (m *MockStore) UpdateNetworkRouter(ctx context.Context, router *types0.NetworkRouter) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateNetworkRouter", ctx, router) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateNetworkRouter indicates an expected call of UpdateNetworkRouter. +func (mr *MockStoreMockRecorder) UpdateNetworkRouter(ctx, router interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNetworkRouter", reflect.TypeOf((*MockStore)(nil).UpdateNetworkRouter), ctx, router) +} + // UpdateProxyHeartbeat mocks base method. func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { m.ctrl.T.Helper() diff --git a/management/server/testdata/networks.sql b/management/server/testdata/networks.sql index bcb202084..911b3bb27 100644 --- a/management/server/testdata/networks.sql +++ b/management/server/testdata/networks.sql @@ -9,9 +9,13 @@ INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO networks VALUES('testNetworkId','testAccountId','some-name','some-description'); +INSERT INTO networks VALUES('secondNetworkId','testAccountId','second-name','second-description'); CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999); +INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',1,'otherNetworkIdentifier','{"IP":"100.65.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO networks VALUES('otherNetworkId','otherAccountId','other-net','other-description'); +INSERT INTO network_routers VALUES('otherRouterId','otherNetworkId','otherAccountId','otherPeer',NULL,0,1); CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32'); From 37052fd5bc050f8eaf80b986d511f3c278233cca Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 21 May 2026 01:46:51 +0900 Subject: [PATCH 22/31] [client] Fix nil channel panic in external chain monitor stop (#6224) --- client/firewall/nftables/external_chain_monitor_linux.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/firewall/nftables/external_chain_monitor_linux.go b/client/firewall/nftables/external_chain_monitor_linux.go index 2a2e04c09..9c91c95cf 100644 --- a/client/firewall/nftables/external_chain_monitor_linux.go +++ b/client/firewall/nftables/external_chain_monitor_linux.go @@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - m.done = make(chan struct{}) + done := make(chan struct{}) + m.done = done - go m.run(ctx) + go m.run(ctx, done) } func (m *externalChainMonitor) stop() { @@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() { <-done } -func (m *externalChainMonitor) run(ctx context.Context) { - defer close(m.done) +func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) { + defer close(done) bo := &backoff.ExponentialBackOff{ InitialInterval: externalMonitorInitInterval, From 0358be23136da50e829ba99a83e54ef555071a7f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 21 May 2026 23:27:12 +0900 Subject: [PATCH 23/31] [client] Revert "Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176)" (#6232) This reverts commit d927ef468a73a15c734987b9cd5478f2a5b12738. --- client/installer.nsis | 23 +++++------------------ client/netbird.wxs | 25 ------------------------- 2 files changed, 5 insertions(+), 43 deletions(-) diff --git a/client/installer.nsis b/client/installer.nsis index 3e057df10..63bff1c5b 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" -; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view -; or HKCU by legacy installers. -DetailPrint "Cleaning legacy 32-bit / HKCU entries..." -DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" -SetRegView 32 -DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" -DeleteRegKey HKLM "${REG_APP_PATH}" -DeleteRegKey HKLM "${UI_REG_APP_PATH}" -DeleteRegKey HKLM "${UNINSTALL_PATH}" -SetRegView 64 - +; Create autostart registry entry based on checkbox DetailPrint "Autostart enabled: $AutostartEnabled" ${If} $AutostartEnabled == "1" WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"' DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" ${Else} DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" + ; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. + DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DetailPrint "Autostart not enabled by user" ${EndIf} @@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' DetailPrint "Terminating Netbird UI process..." ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` -; Remove autostart entries from every view a previous installer may have used. +; Remove autostart registry entry DetailPrint "Removing autostart registry entry if exists..." DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" +; Legacy: pre-HKLM installs wrote to HKCU; clean that up too. DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" -SetRegView 32 -DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" -DeleteRegKey HKLM "${REG_APP_PATH}" -DeleteRegKey HKLM "${UI_REG_APP_PATH}" -DeleteRegKey HKLM "${UNINSTALL_PATH}" -SetRegView 64 ; Handle data deletion based on checkbox DetailPrint "Checking if user requested data deletion..." diff --git a/client/netbird.wxs b/client/netbird.wxs index 96814ce52..6f18b63b5 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -64,13 +64,6 @@ - - - - - @@ -83,28 +76,10 @@ - - - - - - - - - - - From 7aebdd69dd3acd40f827ae87b14d3a4b3e620dbf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 25 May 2026 17:41:50 +0200 Subject: [PATCH 24/31] [management, client, proxy] add expose NetBird-only services over tunnel peers (#6226) Adds a new "private" service mode for the reverse proxy: services reachable exclusively over the embedded WireGuard tunnel, gated by per-peer group membership instead of operator auth schemes. Wire contract - ProxyMapping.private (field 13): the proxy MUST call ValidateTunnelPeer and fail closed; operator schemes are bypassed. - ProxyCapabilities.private (4) + supports_private_service (5): capability gate. Management never streams private mappings to proxies that don't claim the capability; the broadcast path applies the same filter via filterMappingsForProxy. - ValidateTunnelPeer RPC: resolves an inbound tunnel IP to a peer, checks the peer's groups against service.AccessGroups, and mints a session JWT on success. checkPeerGroupAccess fails closed when a private service has empty AccessGroups. - ValidateSession/ValidateTunnelPeer responses now carry peer_group_ids + peer_group_names so the proxy can authorise policy-aware middlewares without an extra management round-trip. - ProxyInboundListener + SendStatusUpdate.inbound_listener: per-account inbound listener state surfaced to dashboards. - PathTargetOptions.direct_upstream (11): bypass the embedded NetBird client and dial the target via the proxy host's network stack for upstreams reachable without WireGuard. Data model - Service.Private (bool) + Service.AccessGroups ([]string, JSON- serialised). Validate() rejects bearer auth on private services. Copy() deep-copies AccessGroups. pgx getServices loads the columns. - DomainConfig.Private threaded into the proxy auth middleware. Request handler routes private services through forwardWithTunnelPeer and returns 403 on validation failure. - Account-level SynthesizePrivateServiceZones (synthetic DNS) and injectPrivateServicePolicies (synthetic ACL) gate on len(svc.AccessGroups) > 0. Proxy - /netbird proxy --private (embedded mode) flag; Config.Private in proxy/lifecycle.go. - Per-account inbound listener (proxy/inbound.go) binding HTTP/HTTPS on the embedded NetBird client's WireGuard tunnel netstack. - proxy/internal/auth/tunnel_cache: ValidateTunnelPeer response cache with single-flight de-duplication and per-account eviction. - Local peerstore short-circuit: when the inbound IP isn't in the account roster, deny fast without an RPC. - proxy/server.go reports SupportsPrivateService=true and redacts the full ProxyMapping JSON from info logs (auth_token + header-auth hashed values now only at debug level). Identity forwarding - ValidateSessionJWT returns user_id, email, method, groups, group_names. sessionkey.Claims carries Email + Groups + GroupNames so the proxy can stamp identity onto upstream requests without an extra management round-trip on every cookie-bearing request. - CapturedData carries userEmail / userGroups / userGroupNames; the proxy stamps X-NetBird-User and X-NetBird-Groups on r.Out from the authenticated identity (strips client-supplied values first to prevent spoofing). - AccessLog.UserGroups: access-log enrichment captures the user's group memberships at write time so the dashboard can render group context without reverse-resolving stale memberships. OpenAPI/dashboard surface - ReverseProxyService gains private + access_groups; ReverseProxyCluster gains private + supports_private. ReverseProxyTarget target_type enum gains "cluster". ServiceTargetOptions gains direct_upstream. ProxyAccessLog gains user_groups. --- .github/workflows/proto-version-check.yml | 66 +- client/embed/embed.go | 22 + client/internal/dns/local/local.go | 100 ++ client/internal/dns/local/local_test.go | 126 ++ client/internal/dns/server.go | 27 + client/internal/peer/status.go | 82 +- client/internal/peer/status_test.go | 27 + management/internals/modules/peers/manager.go | 29 + .../internals/modules/peers/manager_mock.go | 61 +- .../modules/reverseproxy/domain/domain.go | 2 + .../reverseproxy/domain/manager/api.go | 1 + .../reverseproxy/domain/manager/manager.go | 3 + .../domain/manager/manager_test.go | 7 +- .../modules/reverseproxy/proxy/manager.go | 1 + .../reverseproxy/proxy/manager/manager.go | 7 +- .../proxy/manager/manager_test.go | 23 +- .../reverseproxy/proxy/manager_mock.go | 14 + .../modules/reverseproxy/proxy/proxy.go | 8 +- .../reverseproxy/service/manager/api.go | 1 + .../reverseproxy/service/manager/manager.go | 20 +- .../service/manager/manager_test.go | 63 + .../modules/reverseproxy/service/service.go | 144 +- .../reverseproxy/service/service_test.go | 189 +++ .../reverseproxy/sessionkey/sessionkey.go | 27 +- management/internals/shared/grpc/proxy.go | 272 +++- .../shared/grpc/proxy_group_access_test.go | 51 + .../shared/grpc/validate_session_test.go | 4 +- management/server/metrics/selfhosted.go | 39 + management/server/metrics/selfhosted_test.go | 71 +- management/server/peer.go | 23 + management/server/store/file_store.go | 6 + management/server/store/sql_store.go | 57 +- .../server/store/sql_store_service_test.go | 46 + management/server/store/store.go | 30 + management/server/store/store_mock.go | 29 + management/server/types/account.go | 162 ++- management/server/types/account_components.go | 1 + .../types/account_private_netmap_test.go | 85 ++ .../types/account_private_zones_test.go | 256 ++++ management/server/types/account_test.go | 208 ++- management/server/users/manager.go | 34 + proxy/auth/auth.go | 44 +- proxy/cmd/proxy/cmd/root.go | 11 +- proxy/inbound.go | 547 ++++++++ proxy/inbound_test.go | 502 +++++++ proxy/internal/auth/identity.go | 47 + proxy/internal/auth/middleware.go | 216 ++- proxy/internal/auth/middleware_test.go | 260 +++- proxy/internal/auth/tunnel_cache.go | 171 +++ proxy/internal/auth/tunnel_cache_test.go | 171 +++ proxy/internal/auth/tunnel_lookup_test.go | 325 +++++ proxy/internal/debug/client.go | 53 +- proxy/internal/debug/handler.go | 65 +- proxy/internal/proxy/context.go | 86 +- proxy/internal/proxy/reverseproxy.go | 72 + proxy/internal/proxy/reverseproxy_test.go | 242 ++++ proxy/internal/proxy/servicemapping.go | 4 + proxy/internal/restrict/restrict.go | 12 + proxy/internal/restrict/restrict_test.go | 28 + proxy/internal/roundtrip/multi.go | 112 ++ proxy/internal/roundtrip/multi_test.go | 134 ++ proxy/internal/roundtrip/netbird.go | 170 ++- proxy/internal/roundtrip/netbird_test.go | 31 + proxy/internal/tcp/bench_test.go | 4 +- proxy/internal/tcp/router.go | 130 +- proxy/internal/tcp/router_test.go | 94 ++ proxy/internal/tcp/snipeek.go | 16 +- proxy/internal/tcp/snipeek_test.go | 16 +- proxy/internal/types/types.go | 20 + proxy/lifecycle.go | 160 +++ proxy/management_integration_test.go | 5 + proxy/server.go | 630 ++++++--- proxy/server_test.go | 156 +++ shared/management/client/rest/client.go | 5 + .../client/rest/reverse_proxy_clusters.go | 21 +- .../rest/reverse_proxy_clusters_test.go | 90 ++ .../rest/reverse_proxy_services_test.go | 25 +- .../client/rest/reverse_proxy_tokens.go | 72 + .../client/rest/reverse_proxy_tokens_test.go | 131 ++ shared/management/http/api/openapi.yml | 41 +- shared/management/http/api/types.gen.go | 35 +- shared/management/proto/proxy_service.pb.go | 1226 +++++++++++------ shared/management/proto/proxy_service.proto | 89 ++ .../management/proto/proxy_service_grpc.pb.go | 50 + 84 files changed, 7810 insertions(+), 933 deletions(-) create mode 100644 management/server/store/sql_store_service_test.go create mode 100644 management/server/types/account_private_netmap_test.go create mode 100644 management/server/types/account_private_zones_test.go create mode 100644 proxy/inbound.go create mode 100644 proxy/inbound_test.go create mode 100644 proxy/internal/auth/identity.go create mode 100644 proxy/internal/auth/tunnel_cache.go create mode 100644 proxy/internal/auth/tunnel_cache_test.go create mode 100644 proxy/internal/auth/tunnel_lookup_test.go create mode 100644 proxy/internal/roundtrip/multi.go create mode 100644 proxy/internal/roundtrip/multi_test.go create mode 100644 proxy/lifecycle.go create mode 100644 shared/management/client/rest/reverse_proxy_clusters_test.go create mode 100644 shared/management/client/rest/reverse_proxy_tokens.go create mode 100644 shared/management/client/rest/reverse_proxy_tokens_test.go diff --git a/.github/workflows/proto-version-check.yml b/.github/workflows/proto-version-check.yml index ea300419d..bec503b36 100644 --- a/.github/workflows/proto-version-check.yml +++ b/.github/workflows/proto-version-check.yml @@ -20,34 +20,66 @@ jobs: per_page: 100, }); - const pbFiles = files.filter(f => f.filename.endsWith('.pb.go')); - const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename); - if (missingPatch.length > 0) { - core.setFailed( - `Cannot inspect patch data for:\n` + - missingPatch.map(f => `- ${f}`).join('\n') + - `\nThis can happen with very large PRs. Verify proto versions manually.` - ); + const modifiedPbFiles = files.filter( + f => f.filename.endsWith('.pb.go') && f.status === 'modified' + ); + if (modifiedPbFiles.length === 0) { + console.log('No modified .pb.go files to check'); return; } - const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/; - const violations = []; - for (const file of pbFiles) { - const changed = file.patch - .split('\n') - .filter(line => versionPattern.test(line)); - if (changed.length > 0) { + const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/; + const baseSha = context.payload.pull_request.base.sha; + const headSha = context.payload.pull_request.head.sha; + + async function getVersionHeader(path, ref) { + try { + const res = await github.rest.repos.getContent({ + owner: context.repo.owner, + repo: context.repo.repo, + path, + ref, + }); + if (!res.data.content) { + return { ok: false, reason: 'no inline content (file too large)' }; + } + const content = Buffer.from(res.data.content, 'base64').toString('utf8'); + const lines = content + .split('\n') + .slice(0, 20) + .filter(line => versionPattern.test(line)); + return { ok: true, lines }; + } catch (e) { + return { ok: false, reason: e.message }; + } + } + + const violations = []; + for (const file of modifiedPbFiles) { + const [base, head] = await Promise.all([ + getVersionHeader(file.filename, baseSha), + getVersionHeader(file.filename, headSha), + ]); + if (!base.ok || !head.ok) { + core.warning( + `Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}` + ); + continue; + } + if (base.lines.join('\n') !== head.lines.join('\n')) { violations.push({ file: file.filename, - lines: changed, + base: base.lines, + head: head.lines, }); } } if (violations.length > 0) { const details = violations.map(v => - `${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}` + `${v.file}:\n` + + ` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` + + ` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}` ).join('\n\n'); core.setFailed( diff --git a/client/embed/embed.go b/client/embed/embed.go index 8b669e547..7e7f6c337 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -84,6 +84,12 @@ type Options struct { DisableIPv6 bool // BlockInbound blocks all inbound connections from peers BlockInbound bool + // BlockLANAccess blocks the embedded peer from reaching the host's + // LAN (RFC 1918, link-local, loopback) when it's used as a routing + // peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful + // when the embedded client must never act as a stepping stone into + // the host's local network (e.g. the proxy's overlay peer). + BlockLANAccess bool // WireguardPort is the port for the tunnel interface. Use 0 for a random port. WireguardPort *int // MTU is the MTU for the tunnel interface. @@ -175,6 +181,7 @@ func New(opts Options) (*Client, error) { DisableClientRoutes: &opts.DisableClientRoutes, DisableIPv6: &opts.DisableIPv6, BlockInbound: &opts.BlockInbound, + BlockLANAccess: &opts.BlockLANAccess, WireguardPort: opts.WireguardPort, MTU: opts.MTU, DNSLabels: parsedLabels, @@ -405,6 +412,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, }, nil } +// IdentityForIP looks up a remote peer by its tunnel IP using the +// embedded client's status recorder. Returns the peer's WireGuard public +// key and FQDN. ok=false means the IP isn't in this client's peer +// roster โ€” callers should treat that as "unknown peer". +func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) { + if !ip.IsValid() || c.recorder == nil { + return "", "", false + } + state, found := c.recorder.PeerStateByIP(ip.String()) + if !found { + return "", "", false + } + return state.PubKey, state.FQDN, true +} + // Status returns the current status of the client. func (c *Client) Status() (peer.FullStatus, error) { c.mu.Lock() diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index 4a75a76b6..d13aa672e 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -26,6 +26,19 @@ type resolver interface { LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) } +// PeerConnectivity reports whether a tunnel IP belongs to a peer the +// client knows about and whether that peer is currently connected. The +// local resolver uses this to suppress A/AAAA answers whose RDATA points +// at a disconnected peer (typical case: a synthesized private-service +// record pointing at an embedded proxy peer that just went offline). +// +// known=false means the IP isn't in the local peerstore at all โ€” the +// record is left alone (it points at something outside our mesh, e.g. +// a non-peer upstream). +type PeerConnectivity interface { + IsConnectedByIP(ip string) (known, connected bool) +} + type Resolver struct { mu sync.RWMutex records map[dns.Question][]dns.RR @@ -33,6 +46,11 @@ type Resolver struct { // zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone) zones map[domain.Domain]bool resolver resolver + // peerConn, when non-nil, is consulted on every A/AAAA answer to + // drop records pointing at disconnected peers. nil disables the + // filter and preserves the legacy "return whatever is registered" + // behaviour for callers that never wire a status source. + peerConn PeerConnectivity ctx context.Context cancel context.CancelFunc @@ -49,6 +67,15 @@ func NewResolver() *Resolver { } } +// SetPeerConnectivity wires the per-IP connectivity check used to filter +// out A/AAAA answers pointing at disconnected peers. Pass nil to disable. +// Safe to call multiple times; the latest value wins. +func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) { + d.mu.Lock() + defer d.mu.Unlock() + d.peerConn = p +} + func (d *Resolver) MatchSubdomains() bool { return true } @@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { replyMessage.RecursionAvailable = true result := d.lookupRecords(logger, question) + result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records) replyMessage.Authoritative = !result.hasExternalData replyMessage.Answer = result.records replyMessage.Rcode = d.determineRcode(question, result) @@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, } } +// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches +// a known but disconnected peer. The synthesized private-service zones +// emit one A record per connected proxy peer in a cluster; when a peer +// goes offline, the server-side refresh removes the record from the +// next netmap, but the client may still hold the previous netmap for a +// short window. This filter is the local belt to that braces โ€” even on +// the stale netmap, the resolver hides the offline target. +// +// Records pointing at unknown IPs (outside the local peerstore, e.g. +// non-mesh upstreams) are never dropped. Non-A/AAAA records pass +// through untouched. +// +// Escape hatch: if filtering would leave the answer empty AND at least +// one record was filtered, the original list is returned. Better to +// hand the client a record that may not respond than NXDOMAIN it +// completely when every proxy peer is offline (the upstream may still +// be reachable some other way, or the peerstore may be stale). +func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR { + if len(records) == 0 { + return records + } + d.mu.RLock() + checker := d.peerConn + d.mu.RUnlock() + if checker == nil { + return records + } + + kept := make([]dns.RR, 0, len(records)) + var dropped int + for _, rr := range records { + ip := extractRecordIP(rr) + if ip == "" { + kept = append(kept, rr) + continue + } + known, connected := checker.IsConnectedByIP(ip) + if known && !connected { + dropped++ + continue + } + kept = append(kept, rr) + } + if dropped == 0 { + return records + } + if len(kept) == 0 { + logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name) + return records + } + logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept)) + return kept +} + +// extractRecordIP returns the dotted-decimal / colon-hex IP carried by +// an A or AAAA record, or "" for any other record type. +func extractRecordIP(rr dns.RR) string { + switch r := rr.(type) { + case *dns.A: + if r.A == nil { + return "" + } + return r.A.String() + case *dns.AAAA: + if r.AAAA == nil { + return "" + } + return r.AAAA.String() + } + return "" +} + // Update replaces all zones and their records func (d *Resolver) Update(customZones []nbdns.CustomZone) { d.mu.Lock() diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 2c6b7dbc3..fdf7f2659 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([ return nil, nil } +// mockPeerConnectivity returns canned (known, connected) results per IP. +// Used by the disconnected-peer filter tests below. IPs not in the map +// are reported as unknown so the filter leaves them alone. +type mockPeerConnectivity struct { + byIP map[string]struct{ known, connected bool } +} + +func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) { + v, ok := m.byIP[ip] + if !ok { + return false, false + } + return v.known, v.connected +} + func TestLocalResolver_ServeDNS(t *testing.T) { recordA := nbdns.SimpleRecord{ Name: "peera.netbird.cloud.", @@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) { resolver.isInManagedZone(qname) } } + +// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the +// connectivity-aware filtering layered on top of lookupRecords: +// when an A record's IP belongs to a known peer that's disconnected, +// the record is dropped from the answer. Records for unknown IPs pass +// through. If filtering would empty the answer entirely and at least +// one record was dropped, the original list is restored (escape hatch +// for the "all proxies offline" case). +func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) { + zone := "svc.cluster.netbird." + connectedRec := nbdns.SimpleRecord{ + Name: zone, + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 5, + RData: "100.64.0.10", + } + disconnectedRec := nbdns.SimpleRecord{ + Name: zone, + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 5, + RData: "100.64.0.11", + } + unknownRec := nbdns.SimpleRecord{ + Name: zone, + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 5, + RData: "203.0.113.5", + } + + type ipState struct{ known, connected bool } + tests := []struct { + name string + records []nbdns.SimpleRecord + connByIP map[string]ipState + wantInOrder []string + }{ + { + name: "drops disconnected peer, keeps connected", + records: []nbdns.SimpleRecord{connectedRec, disconnectedRec}, + connByIP: map[string]ipState{ + "100.64.0.10": {known: true, connected: true}, + "100.64.0.11": {known: true, connected: false}, + }, + wantInOrder: []string{"100.64.0.10"}, + }, + { + name: "unknown IPs pass through untouched", + records: []nbdns.SimpleRecord{unknownRec, disconnectedRec}, + connByIP: map[string]ipState{ + "100.64.0.11": {known: true, connected: false}, + }, + wantInOrder: []string{"203.0.113.5"}, + }, + { + name: "all disconnected falls back to original list", + records: []nbdns.SimpleRecord{disconnectedRec, connectedRec}, + connByIP: map[string]ipState{ + "100.64.0.10": {known: true, connected: false}, + "100.64.0.11": {known: true, connected: false}, + }, + wantInOrder: []string{"100.64.0.11", "100.64.0.10"}, + }, + { + name: "no checker wired returns all records", + records: []nbdns.SimpleRecord{connectedRec, disconnectedRec}, + connByIP: nil, + wantInOrder: []string{"100.64.0.10", "100.64.0.11"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resolver := NewResolver() + if tc.connByIP != nil { + cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))} + for ip, st := range tc.connByIP { + cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected} + } + resolver.SetPeerConnectivity(cm) + } + resolver.Update([]nbdns.CustomZone{{ + Domain: strings.TrimSuffix(zone, "."), + Records: tc.records, + NonAuthoritative: true, + }}) + + var got *dns.Msg + writer := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + got = m + return nil + }, + } + req := new(dns.Msg).SetQuestion(zone, dns.TypeA) + resolver.ServeDNS(writer, req) + + require.NotNil(t, got, "resolver must produce a response") + require.Len(t, got.Answer, len(tc.wantInOrder), + "answer count must match expected: %v", tc.wantInOrder) + for i, want := range tc.wantInOrder { + a, ok := got.Answer[i].(*dns.A) + require.True(t, ok, "answer[%d] must be an A record", i) + assert.Equal(t, want, a.A.String(), + "answer[%d] expected %s got %s", i, want, a.A.String()) + } + }) + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index e689f3586..7a35e56d8 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -301,6 +301,11 @@ func newDefaultServer( warningDelayBase: defaultWarningDelayBase, healthRefresh: make(chan struct{}, 1), } + // Wire the local resolver against the peer status recorder so it can + // suppress A/AAAA answers that point at disconnected peers (typical + // case: synthesised private-service records pointing at an embedded + // proxy peer that just went offline). + defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder}) // register with root zone, handler chain takes care of the routing dnsService.RegisterMux(".", handlerChain) @@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { } return nil } + +// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so +// the local resolver can ask "is this IP a known peer and is it +// connected?" without taking on the peer package as a dependency. +// A nil status recorder always reports known=false so the resolver +// short-circuits to the legacy "return everything" path. +type localPeerConnectivity struct { + status *peer.Status +} + +// IsConnectedByIP looks the IP up in the peerstore and surfaces both +// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers. +func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) { + if l.status == nil { + return false, false + } + state, ok := l.status.PeerStateByIP(ip) + if !ok { + return false, false + } + return true, state.ConnStatus == peer.StatusConnected +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index df746fa13..f9eb9adf5 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState { return s.eventsChan } -// Status holds a state of peers, signal, management connections and relays +// Status holds a state of peers, signal, management connections and relays. +// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for +// every private-service request) don't contend against each other. +// Pure read methods take RLock; anything that mutates state takes Lock. type Status struct { - mux sync.Mutex + mux sync.RWMutex peers map[string]State changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription signalState bool @@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string) // GetPeer adds peer to Daemon status map func (d *Status) GetPeer(peerPubKey string) (State, error) { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() state, ok := d.peers[peerPubKey] if !ok { @@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { } func (d *Status) PeerByIP(ip string) (string, bool) { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() for _, state := range d.peers { if state.IP == ip { @@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) { return "", false } +// PeerStateByIP returns the full peer State for the given tunnel IP. +// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel +// address so dual-stack peers are reachable on either family. Returns the +// zero State and false when no peer matches or the input is empty. +func (d *Status) PeerStateByIP(ip string) (State, bool) { + if ip == "" { + return State{}, false + } + d.mux.RLock() + defer d.mux.RUnlock() + + for _, state := range d.peers { + if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) { + return state, true + } + } + return State{}, false +} + // RemovePeer removes peer from Daemon status map func (d *Status) RemovePeer(peerPubKey string) error { d.mux.Lock() @@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript // GetLocalPeerState returns the local peer state func (d *Status) GetLocalPeerState() LocalPeerState { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return d.localPeer.Clone() } @@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { } func (d *Status) GetRosenpassState() RosenpassState { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return RosenpassState{ d.rosenpassEnabled, d.rosenpassPermissive, @@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState { } func (d *Status) GetLazyConnection() bool { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return d.lazyConnectionEnabled } func (d *Status) GetManagementState() ManagementState { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return ManagementState{ d.mgmAddress, d.managementState, @@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error { // IsLoginRequired determines if a peer's login has expired. func (d *Status) IsLoginRequired() bool { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() // if peer is connected to the management then login is not expired if d.managementState { @@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool { } func (d *Status) GetSignalState() SignalState { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return SignalState{ d.signalAddress, d.signalState, @@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState { // GetRelayStates returns the stun/turn/permanent relay states func (d *Status) GetRelayStates() []relay.ProbeResult { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() if d.relayMgr == nil { return d.relayStates } @@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { } func (d *Status) ForwardingRules() []firewall.ForwardRule { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() if d.ingressGwMgr == nil { return nil } @@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule { } func (d *Status) GetDNSStates() []NSGroupState { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() // shallow copy is good enough, as slices fields are currently not updated return slices.Clone(d.nsGroupStates) } func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() return maps.Clone(d.resolvedDomainsStates) } @@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus { LazyConnectionEnabled: d.GetLazyConnection(), } - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() fullStatus.LocalPeerState = d.localPeer @@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) { } func (d *Status) PeersStatus() (*configurer.Stats, error) { - d.mux.Lock() - defer d.mux.Unlock() + d.mux.RLock() + defer d.mux.RUnlock() if d.wgIface == nil { return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status") } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index 9bafca55a..8d889b0ae 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) { assert.Equal(t, ip, state.IP, "ip should be equal") } +func TestStatus_PeerStateByIP(t *testing.T) { + status := NewRecorder("https://mgm") + req := require.New(t) + + req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "")) + req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", "")) + + state, ok := status.PeerStateByIP("100.64.0.10") + req.True(ok, "known tunnel IP should resolve to a peer state") + req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key") + req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN") + + _, ok = status.PeerStateByIP("100.64.0.99") + req.False(ok, "unknown IP must report ok=false") +} + +func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) { + status := NewRecorder("https://mgm") + req := require.New(t) + + req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1")) + + state, ok := status.PeerStateByIP("fd00::1") + req.True(ok, "IPv6-only match must resolve to the peer state") + req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key") +} + func TestStatus_UpdatePeerFQDN(t *testing.T) { key := "abc" fqdn := "peer-a.netbird.local" diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index c913efb92..75ae8de91 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -5,6 +5,7 @@ package peers import ( "context" "fmt" + "net" "time" "github.com/rs/xid" @@ -35,6 +36,14 @@ type Manager interface { SetAccountManager(accountManager account.Manager) GetPeerID(ctx context.Context, peerKey string) (string, error) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error + // GetPeerByTunnelIP looks up a peer in accountID by its WireGuard tunnel IP. + // Returns nil with an error when no match exists. No permission check; + // callers (the proxy's ValidateTunnelPeer RPC) are trusted server components. + GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) + // GetPeerWithGroups returns the peer and the list of *types.Group it belongs + // to. Used by the proxy's auth path to authorise a request by the calling + // peer's group memberships. + GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) } type managerImpl struct { @@ -99,6 +108,26 @@ func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) } +// GetPeerByTunnelIP delegates to the store's indexed lookup. +func (m *managerImpl) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) { + return m.store.GetPeerByIP(ctx, store.LockingStrengthNone, accountID, ip) +} + +// GetPeerWithGroups returns the peer plus its group memberships. Any store +// error returns (nil, nil, err) so callers never receive a valid peer +// alongside a non-nil error. +func (m *managerImpl) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) { + p, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return nil, nil, err + } + groups, err := m.store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return nil, nil, err + } + return p, groups, nil +} + func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { diff --git a/management/internals/modules/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go index d6c9ebacc..3836ac909 100644 --- a/management/internals/modules/peers/manager_mock.go +++ b/management/internals/modules/peers/manager_mock.go @@ -6,6 +6,7 @@ package peers import ( context "context" + net "net" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -13,6 +14,7 @@ import ( account "github.com/netbirdio/netbird/management/server/account" integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" peer "github.com/netbirdio/netbird/management/server/peer" + types "github.com/netbirdio/netbird/management/server/types" ) // MockManager is a mock of Manager interface. @@ -38,6 +40,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// CreateProxyPeer mocks base method. +func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID, peerKey, cluster string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateProxyPeer indicates an expected call of CreateProxyPeer. +func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster) +} + // DeletePeers mocks base method. func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { m.ctrl.T.Helper() @@ -97,6 +113,21 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } +// GetPeerByTunnelIP mocks base method. +func (m *MockManager) GetPeerByTunnelIP(ctx context.Context, accountID string, ip net.IP) (*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerByTunnelIP", ctx, accountID, ip) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeerByTunnelIP indicates an expected call of GetPeerByTunnelIP. +func (mr *MockManagerMockRecorder) GetPeerByTunnelIP(ctx, accountID, ip interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerByTunnelIP", reflect.TypeOf((*MockManager)(nil).GetPeerByTunnelIP), ctx, accountID, ip) +} + // GetPeerID mocks base method. func (m *MockManager) GetPeerID(ctx context.Context, peerKey string) (string, error) { m.ctrl.T.Helper() @@ -112,6 +143,22 @@ func (mr *MockManagerMockRecorder) GetPeerID(ctx, peerKey interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerID", reflect.TypeOf((*MockManager)(nil).GetPeerID), ctx, peerKey) } +// GetPeerWithGroups mocks base method. +func (m *MockManager) GetPeerWithGroups(ctx context.Context, accountID, peerID string) (*peer.Peer, []*types.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerWithGroups", ctx, accountID, peerID) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].([]*types.Group) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetPeerWithGroups indicates an expected call of GetPeerWithGroups. +func (mr *MockManagerMockRecorder) GetPeerWithGroups(ctx, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerWithGroups", reflect.TypeOf((*MockManager)(nil).GetPeerWithGroups), ctx, accountID, peerID) +} + // GetPeersByGroupIDs mocks base method. func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { m.ctrl.T.Helper() @@ -162,17 +209,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) } - -// CreateProxyPeer mocks base method. -func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateProxyPeer indicates an expected call of CreateProxyPeer. -func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster) -} diff --git a/management/internals/modules/reverseproxy/domain/domain.go b/management/internals/modules/reverseproxy/domain/domain.go index f65e31a07..08d7ad19b 100644 --- a/management/internals/modules/reverseproxy/domain/domain.go +++ b/management/internals/modules/reverseproxy/domain/domain.go @@ -23,6 +23,8 @@ type Domain struct { // SupportsCrowdSec is populated at query time from proxy cluster capabilities. // Not persisted. SupportsCrowdSec *bool `gorm:"-"` + // SupportsPrivate is populated at query time from proxy cluster capabilities. Not persisted. + SupportsPrivate *bool `gorm:"-"` } // EventMeta returns activity event metadata for a domain diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 4493ef0ad..f01329010 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -49,6 +49,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain { SupportsCustomPorts: d.SupportsCustomPorts, RequireSubdomain: d.RequireSubdomain, SupportsCrowdsec: d.SupportsCrowdSec, + SupportsPrivate: d.SupportsPrivate, } if d.TargetCluster != "" { resp.TargetCluster = &d.TargetCluster diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 2790b5f20..2a026c7fa 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -35,6 +35,7 @@ type proxyManager interface { ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool } type Manager struct { @@ -93,6 +94,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) d.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, cluster) + d.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, cluster) ret = append(ret, d) } @@ -109,6 +111,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d if d.TargetCluster != "" { cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) cd.SupportsCrowdSec = m.proxyManager.ClusterSupportsCrowdSec(ctx, d.TargetCluster) + cd.SupportsPrivate = m.proxyManager.ClusterSupportsPrivate(ctx, d.TargetCluster) } // Custom domains never require a subdomain by default since // the account owns them and should be able to use the bare domain. diff --git a/management/internals/modules/reverseproxy/domain/manager/manager_test.go b/management/internals/modules/reverseproxy/domain/manager/manager_test.go index 5e7bbfc36..53a8dedae 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager_test.go @@ -10,7 +10,7 @@ import ( ) type mockProxyManager struct { - getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error) getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error) } @@ -40,6 +40,10 @@ func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) return nil } +func (m *mockProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool { + return nil +} + func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) { pm := &mockProxyManager{ getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) { @@ -151,4 +155,3 @@ func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"byop.example.com"}, result) } - diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 07ea6f0ab..22f1007ec 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -19,6 +19,7 @@ type Manager interface { ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) CountAccountProxies(ctx context.Context, accountID string) (int64, error) diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 510500e0c..943766004 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -21,6 +21,7 @@ type store interface { GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) @@ -137,6 +138,11 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string return m.store.GetClusterSupportsCrowdSec(ctx, clusterAddr) } +// ClusterSupportsPrivate reports whether any active proxy claims the private capability (nil = unreported). +func (m Manager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsPrivate(ctx, clusterAddr) +} + // CleanupStale removes proxies that haven't sent heartbeat in the specified duration func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { @@ -178,4 +184,3 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, acco } return nil } - diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go index 3436216b4..5c44470a3 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -15,16 +15,16 @@ import ( ) type mockStore struct { - saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error - disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error - updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error - getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) - getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) - cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error - getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) - countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) - isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) - deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error + saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error + disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error + updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error + getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error) + getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error) + cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error + getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error) + countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error) + isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error) + deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error } func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { @@ -99,6 +99,9 @@ func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *boo func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool { return nil } +func (m *mockStore) GetClusterSupportsPrivate(_ context.Context, _ string) *bool { + return nil +} func newTestManager(s store) *Manager { meter := noop.NewMeterProvider().Meter("test") diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index a0e360a1b..d2be46c9f 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -92,6 +92,20 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCrowdSec", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCrowdSec), ctx, clusterAddr) } +// ClusterSupportsPrivate mocks base method. +func (m *MockManager) ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterSupportsPrivate", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsPrivate indicates an expected call of ClusterSupportsPrivate. +func (mr *MockManagerMockRecorder) ClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsPrivate", reflect.TypeOf((*MockManager)(nil).ClusterSupportsPrivate), ctx, clusterAddr) +} + // Connect mocks base method. func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 9da7910df..4404b0d24 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -20,6 +20,9 @@ type Capabilities struct { RequireSubdomain *bool // SupportsCrowdsec indicates whether this proxy has CrowdSec configured. SupportsCrowdsec *bool + // Private indicates whether this proxy supports inbound access via Wireguard + // tunnel and netbird-only authentication policies + Private *bool } // Proxy represents a reverse proxy instance @@ -67,10 +70,9 @@ type Cluster struct { Type ClusterType Online bool ConnectedProxies int - // Capability flags. *bool because nil means "no proxy reported a - // capability for this cluster" โ€” the dashboard renders these as - // unknown rather than false. + // *bool: nil = no proxy reported the capability; the dashboard renders that as unknown. SupportsCustomPorts *bool RequireSubdomain *bool SupportsCrowdSec *bool + Private *bool } diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index 9d93d52ee..7298b4261 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -204,6 +204,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, SupportsCrowdsec: c.SupportsCrowdSec, + Private: c.Private, }) } diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ca0c5540f..f0ac68ed0 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -82,6 +82,7 @@ type CapabilityProvider interface { ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + ClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool } type Manager struct { @@ -136,6 +137,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([] clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address) clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address) clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address) + clusters[i].Private = m.capabilities.ClusterSupportsPrivate(ctx, clusters[i].Address) } return clusters, nil @@ -208,6 +210,9 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s * target.Host = resource.Domain case service.TargetTypeSubnet: // For subnets we do not do any lookups on the resource + case service.TargetTypeCluster: + // Cluster targets carry the upstream address on target_id; the + // proxy resolves the destination at request time. default: return fmt.Errorf("unknown target type: %s", target.TargetType) } @@ -779,6 +784,10 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil { return err } + case service.TargetTypeCluster: + if err := validateClusterTarget(target); err != nil { + return err + } default: return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) } @@ -786,6 +795,13 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } +func validateClusterTarget(target *service.Target) error { + if !target.Options.DirectUpstream { + return status.Errorf(status.InvalidArgument, "cluster target %s has direct upstream disabled", target.Host) + } + return nil +} + func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { @@ -962,12 +978,14 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str return fmt.Errorf("failed to get services: %w", err) } + oidcCfg := m.proxyController.GetOIDCValidationConfig() + for _, s := range services { err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster) } return nil diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 47b8b3865..f3ab89a25 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -1344,3 +1344,66 @@ func TestValidateSubdomainRequirement(t *testing.T) { }) } } + +func TestValidateTargetReferences_ClusterTargetSkipsLookup(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + // No peer or resource lookups must be issued for cluster targets. + targets := []*rpservice.Target{ + { + TargetId: "eu.proxy.netbird.io", + TargetType: rpservice.TargetTypeCluster, + Options: rpservice.TargetOptions{DirectUpstream: true}, + }, + } + require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets), "cluster target must validate without store lookups") +} + +// TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream pins the +// store-side check that cluster targets must opt into the host-stack dial +// path. Without DirectUpstream the proxy would route this target through +// the embedded NetBird client and fail on every request. +func TestValidateTargetReferences_ClusterTargetRequiresDirectUpstream(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + targets := []*rpservice.Target{ + { + TargetId: "eu.proxy.netbird.io", + TargetType: rpservice.TargetTypeCluster, + Host: "backend.lan", + }, + } + err := validateTargetReferences(ctx, mockStore, accountID, targets) + require.Error(t, err, "cluster target without direct_upstream must be rejected") + assert.ErrorContains(t, err, "direct upstream disabled") +} + +func TestReplaceHostByLookup_SkipsClusterTarget(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + mgr := &Manager{store: mockStore} + + svc := &rpservice.Service{ + ID: "svc-1", + AccountID: accountID, + Targets: []*rpservice.Target{ + { + TargetId: "eu.proxy.netbird.io", + TargetType: rpservice.TargetTypeCluster, + Host: "127.0.0.1", + }, + }, + } + + require.NoError(t, mgr.replaceHostByLookup(ctx, accountID, svc), "cluster target must not trigger peer/resource lookup") + assert.Equal(t, "127.0.0.1", svc.Targets[0].Host, "operator-supplied host must be preserved for cluster target") +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 166a66a5f..27f6d914d 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -45,10 +45,11 @@ const ( StatusCertificateFailed Status = "certificate_failed" StatusError Status = "error" - TargetTypePeer TargetType = "peer" - TargetTypeHost TargetType = "host" - TargetTypeDomain TargetType = "domain" - TargetTypeSubnet TargetType = "subnet" + TargetTypePeer TargetType = "peer" + TargetTypeHost TargetType = "host" + TargetTypeDomain TargetType = "domain" + TargetTypeSubnet TargetType = "subnet" + TargetTypeCluster TargetType = "cluster" SourcePermanent = "permanent" SourceEphemeral = "ephemeral" @@ -60,6 +61,11 @@ type TargetOptions struct { SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"` PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"` CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"` + // DirectUpstream bypasses the proxy's embedded NetBird client and dials + // the target via the proxy host's network stack. Useful for upstreams + // reachable without WireGuard (public APIs, LAN services, localhost + // sidecars). Default false. + DirectUpstream bool `json:"direct_upstream,omitempty"` } type Target struct { @@ -67,7 +73,7 @@ type Target struct { AccountID string `gorm:"index:idx_target_account;not null" json:"-"` ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"` Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Host string `json:"host"` Port uint16 `gorm:"index:idx_target_port" json:"port"` Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` TargetId string `gorm:"index:idx_target_id" json:"target_id"` @@ -200,6 +206,10 @@ type Service struct { Mode string `gorm:"default:'http'"` ListenPort uint16 PortAutoAssigned bool + // Private marks the service as NetBird-only: auth via ValidateTunnelPeer against AccessGroups instead of SSO. HTTP-only. + Private bool + // AccessGroups is the group ID allowlist for inbound peers on private services. Mutually exclusive with bearer SSO. + AccessGroups []string `json:"access_groups,omitempty" gorm:"serializer:json"` } // InitNewRecord generates a new unique ID and resets metadata for a newly created @@ -299,6 +309,12 @@ func (s *Service) ToAPIResponse() *api.Service { Mode: &mode, ListenPort: &listenPort, PortAutoAssigned: &s.PortAutoAssigned, + Private: &s.Private, + } + + if len(s.AccessGroups) > 0 { + groups := append([]string(nil), s.AccessGroups...) + resp.AccessGroups = &groups } if s.ProxyCluster != "" { @@ -308,6 +324,7 @@ func (s *Service) ToAPIResponse() *api.Service { return resp } +// ToProtoMapping converts the service into the wire format the proxy consumes. func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { pathMappings := s.buildPathMappings() @@ -349,6 +366,7 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf RewriteRedirects: s.RewriteRedirects, Mode: s.Mode, ListenPort: int32(s.ListenPort), //nolint:gosec + Private: s.Private, } if r := restrictionsToProto(s.Restrictions); r != nil { @@ -455,7 +473,8 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode { } func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { - if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 { + if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && + opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 && !opts.DirectUpstream { return nil } apiOpts := &api.ServiceTargetOptions{} @@ -477,17 +496,22 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions { if len(opts.CustomHeaders) > 0 { apiOpts.CustomHeaders = &opts.CustomHeaders } + if opts.DirectUpstream { + apiOpts.DirectUpstream = &opts.DirectUpstream + } return apiOpts } func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions { - if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 { + if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && + len(opts.CustomHeaders) == 0 && !opts.DirectUpstream { return nil } popts := &proto.PathTargetOptions{ - SkipTlsVerify: opts.SkipTLSVerify, - PathRewrite: pathRewriteToProto(opts.PathRewrite), - CustomHeaders: opts.CustomHeaders, + SkipTlsVerify: opts.SkipTLSVerify, + PathRewrite: pathRewriteToProto(opts.PathRewrite), + CustomHeaders: opts.CustomHeaders, + DirectUpstream: opts.DirectUpstream, } if opts.RequestTimeout != 0 { popts.RequestTimeout = durationpb.New(opts.RequestTimeout) @@ -537,6 +561,9 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, if o.CustomHeaders != nil { opts.CustomHeaders = *o.CustomHeaders } + if o.DirectUpstream != nil { + opts.DirectUpstream = *o.DirectUpstream + } return opts, nil } @@ -551,6 +578,14 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro if req.ListenPort != nil { s.ListenPort = uint16(*req.ListenPort) //nolint:gosec } + if req.Private != nil { + s.Private = *req.Private + } + if req.AccessGroups != nil { + s.AccessGroups = append([]string(nil), *req.AccessGroups...) + } else { + s.AccessGroups = nil + } targets, err := targetsFromAPI(accountID, req.Targets) if err != nil { @@ -740,6 +775,9 @@ func (s *Service) Validate() error { if err := validateAccessRestrictions(&s.Restrictions); err != nil { return err } + if err := s.validatePrivateRequirements(); err != nil { + return err + } switch s.Mode { case ModeHTTP: @@ -753,6 +791,23 @@ func (s *Service) Validate() error { } } +// validatePrivateRequirements enforces the private-service contract: HTTP mode, โ‰ฅ1 access group, no bearer auth. +func (s *Service) validatePrivateRequirements() error { + if !s.Private { + return nil + } + if s.Mode != "" && s.Mode != ModeHTTP { + return fmt.Errorf("private services only support HTTP mode, got %q", s.Mode) + } + if len(s.AccessGroups) == 0 { + return errors.New("private services require at least one access group") + } + if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled { + return errors.New("private services cannot enable bearer auth (SSO): NetBird-only access and SSO are mutually exclusive") + } + return nil +} + func (s *Service) validateHTTPMode() error { if s.Domain == "" { return errors.New("service domain is required") @@ -799,11 +854,21 @@ func (s *Service) validateHTTPTargets() error { for i, target := range s.Targets { switch target.TargetType { case TargetTypePeer, TargetTypeHost, TargetTypeDomain: - // host field will be ignored + // Host is normally overwritten by replaceHostByLookup with the + // resolved peer IP / resource address; operator-supplied values + // are honored only when DirectUpstream is set. Validate the + // override here so misconfigured hosts fail fast at API time. + if err := validateDirectUpstreamHost(i, target); err != nil { + return err + } case TargetTypeSubnet: if target.Host == "" { return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType) } + case TargetTypeCluster: + if err := validateClusterTarget(i, target); err != nil { + return err + } default: return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType) } @@ -821,25 +886,67 @@ func (s *Service) validateHTTPTargets() error { return nil } +// validateClusterTarget cluster targets should not have empty hosts and should have direct upstream enabled. +func validateClusterTarget(idx int, target *Target) error { + host := strings.TrimSpace(target.Host) + if host == "" { + return fmt.Errorf("target %d: has empty host", idx) + } + if !target.Options.DirectUpstream { + return fmt.Errorf("target %d: %s has direct upstream disabled", idx, target.Host) + } + return validateDirectUpstreamHost(idx, target) +} + +// validateDirectUpstreamHost validates the operator-supplied Host on a +// peer/host/domain target when DirectUpstream is set. Empty Host is +// allowed โ€” the lookup fills in the default peer IP / resource address. +// Without DirectUpstream the Host value is silently overwritten by +// replaceHostByLookup, so we don't validate it (preserves the historical +// behaviour where APIs accepted any value and dropped it). Non-empty +// Host with DirectUpstream must look like a hostname or IP and must +// not carry a port (port lives on Target.Port). +func validateDirectUpstreamHost(idx int, target *Target) error { + if !target.Options.DirectUpstream { + return nil + } + host := strings.TrimSpace(target.Host) + if host == "" { + return nil + } + if strings.ContainsAny(host, " \t/") { + return fmt.Errorf("target %d: host %q contains invalid characters", idx, host) + } + if _, _, err := net.SplitHostPort(host); err == nil { + return fmt.Errorf("target %d: host %q must not include a port (set target.port instead)", idx, host) + } + return nil +} + func (s *Service) validateL4Target(target *Target) error { // L4 services have a single target; per-target disable is meaningless // (use the service-level Enabled flag instead). Force it on so that // buildPathMappings always includes the target in the proto. target.Enabled = true - if target.Port == 0 { - return errors.New("target port is required for L4 services") - } if target.TargetId == "" { return errors.New("target_id is required for L4 services") } + if target.TargetType != TargetTypeCluster && target.Port == 0 { + return errors.New("target port is required for L4 services") + } switch target.TargetType { case TargetTypePeer, TargetTypeHost, TargetTypeDomain: - // OK + if err := validateDirectUpstreamHost(0, target); err != nil { + return err + } case TargetTypeSubnet: if target.Host == "" { return errors.New("target host is required for subnet targets") } + case TargetTypeCluster: + // target_id carries the cluster address; the proxy resolves + // the upstream at request time. default: return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType) } @@ -1174,6 +1281,11 @@ func (s *Service) Copy() *Service { } } + var accessGroups []string + if len(s.AccessGroups) > 0 { + accessGroups = append([]string(nil), s.AccessGroups...) + } + return &Service{ ID: s.ID, AccountID: s.AccountID, @@ -1195,6 +1307,8 @@ func (s *Service) Copy() *Service { Mode: s.Mode, ListenPort: s.ListenPort, PortAutoAssigned: s.PortAutoAssigned, + Private: s.Private, + AccessGroups: accessGroups, } } diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index f1349ff65..ba63d76ed 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/shared/hash/argon2id" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -1116,3 +1117,191 @@ func TestValidate_HeaderAuths(t *testing.T) { assert.Contains(t, err.Error(), "exceeds maximum length") }) } + +func TestValidate_HTTPClusterTarget(t *testing.T) { + rp := validProxy() + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + require.NoError(t, rp.Validate(), "HTTP cluster target with target_id, host, and direct_upstream must validate") +} + +func TestValidate_HTTPClusterTarget_RequiresTargetId(t *testing.T) { + rp := validProxy() + rp.Targets = []*Target{{ + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "empty target_id", "cluster target must reject empty target_id") +} + +// TestValidate_HTTPClusterTarget_RequiresHost pins the new cluster-target +// rule that operator-supplied Host is mandatory: cluster targets dial the +// upstream via the host network stack (direct_upstream is implied), so an +// empty Host leaves the proxy with nothing to dial. +func TestValidate_HTTPClusterTarget_RequiresHost(t *testing.T) { + rp := validProxy() + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "empty host", "cluster target must reject empty host") +} + +// TestValidate_HTTPClusterTarget_RequiresDirectUpstream pins the second +// half of the cluster-target rule: DirectUpstream must be true so the +// stdlib transport branch in MultiTransport is taken. Without it the +// embedded NetBird client would try to dial the cluster address through +// the WG tunnel, which is the wrong network for a cluster upstream. +func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) { + rp := validProxy() + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false") +} + +func TestValidate_L4ClusterTarget(t *testing.T) { + rp := validProxy() + rp.Mode = ModeTCP + rp.ListenPort = 9000 + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "tcp", + Enabled: true, + }} + require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port") +} + +func TestService_Copy_RoundtripsPrivate(t *testing.T) { + svc := validProxy() + svc.Private = true + svc.AccessGroups = []string{"grp-admins", "grp-ops"} + cp := svc.Copy() + require.NotNil(t, cp) + assert.True(t, cp.Private) + assert.Equal(t, []string{"grp-admins", "grp-ops"}, cp.AccessGroups) + + cp.Private = false + assert.True(t, svc.Private) + + cp.AccessGroups[0] = "grp-other" + assert.Equal(t, []string{"grp-admins", "grp-ops"}, svc.AccessGroups) +} + +func TestService_APIRoundtrip_Private(t *testing.T) { + enabled := true + private := true + accessGroups := []string{"grp-admins"} + targets := []api.ServiceTarget{{ + TargetId: "eu.proxy.netbird.io", + TargetType: api.ServiceTargetTargetType("cluster"), + Protocol: "http", + Port: 80, + Enabled: true, + }} + req := &api.ServiceRequest{ + Name: "svc-private", + Domain: "myapp.eu.proxy.netbird.io", + Enabled: enabled, + Private: &private, + AccessGroups: &accessGroups, + Targets: &targets, + } + + svc := &Service{} + require.NoError(t, svc.FromAPIRequest(req, "acc-1")) + assert.True(t, svc.Private) + assert.Equal(t, []string{"grp-admins"}, svc.AccessGroups) + + resp := svc.ToAPIResponse() + require.NotNil(t, resp.Private) + assert.True(t, *resp.Private) + require.NotNil(t, resp.AccessGroups) + assert.Equal(t, []string{"grp-admins"}, *resp.AccessGroups) +} + +func TestValidate_Private_RequiresAccessGroups(t *testing.T) { + rp := validProxy() + rp.Private = true + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "access group") +} + +func TestValidate_Private_RejectsBearerAuth(t *testing.T) { + rp := validProxy() + rp.Private = true + rp.AccessGroups = []string{"grp-admins"} + rp.Auth.BearerAuth = &BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"grp-sso"}, + } + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "mutually exclusive") +} + +func TestValidate_Private_AcceptsNonClusterTargets(t *testing.T) { + rp := validProxy() + rp.Private = true + rp.AccessGroups = []string{"grp-admins"} + require.NoError(t, rp.Validate()) +} + +func TestValidate_Private_AcceptsClusterTargetWithAccessGroups(t *testing.T) { + rp := validProxy() + rp.Private = true + rp.AccessGroups = []string{"grp-admins"} + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "http", + Host: "backend.lan", + Options: TargetOptions{DirectUpstream: true}, + Enabled: true, + }} + require.NoError(t, rp.Validate()) +} + +func TestValidate_Private_RejectsNonHTTPMode(t *testing.T) { + rp := validProxy() + rp.Private = true + rp.AccessGroups = []string{"grp-admins"} + rp.Mode = ModeTCP + rp.Targets = []*Target{{ + TargetId: "eu.proxy.netbird.io", + TargetType: TargetTypeCluster, + Protocol: "tcp", + Enabled: true, + }} + assert.ErrorContains(t, rp.Validate(), "HTTP") +} diff --git a/management/internals/modules/reverseproxy/sessionkey/sessionkey.go b/management/internals/modules/reverseproxy/sessionkey/sessionkey.go index aacbe5dca..1fb6a323d 100644 --- a/management/internals/modules/reverseproxy/sessionkey/sessionkey.go +++ b/management/internals/modules/reverseproxy/sessionkey/sessionkey.go @@ -20,6 +20,20 @@ type KeyPair struct { type Claims struct { jwt.RegisteredClaims Method auth.Method `json:"method"` + // Email is the calling user's email address. Carried so the + // proxy can stamp identity on upstream requests (e.g. + // x-litellm-end-user-id) without an extra management + // round-trip on every cookie-bearing request. + Email string `json:"email,omitempty"` + // Groups carries the user's group IDs so the proxy can stamp them + // onto upstream requests (X-NetBird-Groups) from the cookie path + // without an extra management round-trip. + Groups []string `json:"groups,omitempty"` + // GroupNames carries the human-readable display names for the ids + // in Groups, ordered identically (positional pairing). Slice may be + // shorter than Groups for tokens minted before names were + // resolvable; the consumer falls back to ids for missing positions. + GroupNames []string `json:"group_names,omitempty"` } func GenerateKeyPair() (*KeyPair, error) { @@ -34,7 +48,13 @@ func GenerateKeyPair() (*KeyPair, error) { }, nil } -func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) { +// SignToken mints a session JWT for the given user and domain. email, +// groups, and groupNames, when non-empty, are embedded so the proxy can +// authorise and stamp identity for policy-aware middlewares without a +// management round-trip on every cookie-bearing request. groupNames +// pairs positionally with groups; pass nil when names couldn't be +// resolved. +func SignToken(privKeyB64, userID, email, domain string, method auth.Method, groups, groupNames []string, expiration time.Duration) (string, error) { privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64) if err != nil { return "", fmt.Errorf("decode private key: %w", err) @@ -56,7 +76,10 @@ func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration IssuedAt: jwt.NewNumericDate(now), NotBefore: jwt.NewNumericDate(now), }, - Method: method, + Method: method, + Email: email, + Groups: append([]string(nil), groups...), + GroupNames: append([]string(nil), groupNames...), } token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 4abeb8e7c..e7155ae09 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -351,6 +351,7 @@ func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, SupportsCrowdsec: c.SupportsCrowdsec, + Private: c.Private, } } @@ -754,6 +755,11 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes InitialSyncComplete: update.InitialSyncComplete, } } + // Drop mappings the proxy lacks capability for (e.g. private without SupportsPrivateService). + connUpdate = filterMappingsForProxy(conn, connUpdate) + if connUpdate == nil || len(connUpdate.Mapping) == 0 { + return true + } resp := s.perProxyMessage(connUpdate, conn.proxyID) if resp == nil { log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID) @@ -882,16 +888,20 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd } } -// proxyAcceptsMapping returns whether the proxy should receive this mapping. -// Old proxies that never reported capabilities are skipped for non-TLS L4 -// mappings with a custom listen port, since they don't understand the -// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) -// are new enough to handle the mapping. TLS uses SNI routing and works on -// any proxy. Delete operations are always sent so proxies can clean up. +// proxyAcceptsMapping returns whether the proxy can receive this mapping. +// Private mappings require SupportsPrivateService; custom-port L4 mappings +// require SupportsCustomPorts. Remove operations always pass so proxies can +// clean up. func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { return true } + if mapping.GetPrivate() { + caps := conn.capabilities + if caps == nil || caps.SupportsPrivateService == nil || !*caps.SupportsPrivateService { + return false + } + } if mapping.ListenPort == 0 || mapping.Mode == "tls" { return true } @@ -900,6 +910,29 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil } +// filterMappingsForProxy drops mappings the proxy cannot safely receive +// (e.g. private mappings to a proxy without SupportsPrivateService). +// Returns the input unchanged when no filtering is needed. +func filterMappingsForProxy(conn *proxyConnection, update *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { + if update == nil || len(update.Mapping) == 0 { + return update + } + kept := make([]*proto.ProxyMapping, 0, len(update.Mapping)) + for _, m := range update.Mapping { + if !proxyAcceptsMapping(conn, m) { + continue + } + kept = append(kept, m) + } + if len(kept) == len(update.Mapping) { + return update + } + return &proto.GetMappingUpdateResponse{ + Mapping: kept, + InitialSyncComplete: update.InitialSyncComplete, + } +} + // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. @@ -961,7 +994,10 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen authenticated, userId, method := s.authenticateRequest(ctx, req, service) - token, err := s.generateSessionToken(ctx, authenticated, service, userId, method) + // Non-OIDC schemes (PIN/Password/Header) authenticate against per-service + // secrets and have no user-level group context, so groups stay nil. Email + // is also empty โ€” these schemes don't resolve a user record at sign time. + token, err := s.generateSessionToken(ctx, authenticated, service, userId, "", method, nil, nil) if err != nil { return nil, err } @@ -1050,7 +1086,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err } } -func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) { +func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId, userEmail string, method proxyauth.Method, groupIDs, groupNames []string) (string, error) { if !authenticated || service.SessionPrivateKey == "" { return "", nil } @@ -1058,8 +1094,11 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic token, err := sessionkey.SignToken( service.SessionPrivateKey, userId, + userEmail, service.Domain, method, + groupIDs, + groupNames, proxyauth.DefaultSessionExpiry, ) if err != nil { @@ -1070,6 +1109,26 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic return token, nil } +// pairGroupIDsAndNames splits a slice of resolved *types.Group records +// into parallel id and name slices. ids[i] and names[i] always pair to +// the same group. nil entries (orphan ids the manager couldn't resolve) +// are skipped so the consumer can rely on positional pairing. +func pairGroupIDsAndNames(groups []*types.Group) (ids, names []string) { + if len(groups) == 0 { + return nil, nil + } + ids = make([]string, 0, len(groups)) + names = make([]string, 0, len(groups)) + for _, g := range groups { + if g == nil { + continue + } + ids = append(ids, g.ID) + names = append(names, g.Name) + } + return ids, names +} + // SendStatusUpdate handles status updates from proxy clients. func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil { @@ -1334,7 +1393,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL return verifier, redirectURL, nil } -// GenerateSessionToken creates a signed session JWT for the given domain and user. +// GenerateSessionToken creates a signed session JWT for the given domain and +// user. The user's group memberships are embedded in the token so policy-aware +// middlewares on the proxy can authorise without an extra management round-trip. func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { service, err := s.getServiceByDomain(ctx, domain) if err != nil { @@ -1345,11 +1406,29 @@ func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, u return "", fmt.Errorf("no session key configured for domain: %s", domain) } + var ( + email string + groupIDs []string + groupNames []string + ) + if s.usersManager != nil { + user, userGroups, uerr := s.usersManager.GetUserWithGroups(ctx, userID) + if uerr != nil { + log.WithContext(ctx).Debugf("session token mint: lookup user %s: %v", userID, uerr) + } else if user != nil { + email = user.Email + groupIDs, groupNames = pairGroupIDsAndNames(userGroups) + } + } + return sessionkey.SignToken( service.SessionPrivateKey, userID, + email, domain, method, + groupIDs, + groupNames, proxyauth.DefaultSessionExpiry, ) } @@ -1453,7 +1532,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } - userID, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes) + userID, _, _, _, _, err := proxyauth.ValidateSessionJWT(sessionToken, domain, pubKeyBytes) if err != nil { log.WithFields(log.Fields{ "domain": domain, @@ -1466,7 +1545,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } - user, err := s.usersManager.GetUser(ctx, userID) + user, userGroups, err := s.usersManager.GetUserWithGroups(ctx, userID) if err != nil { log.WithFields(log.Fields{ "domain": domain, @@ -1500,12 +1579,15 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val "user_id": userID, "error": err.Error(), }).Debug("ValidateSession: access denied") + groupIDs, groupNames := pairGroupIDsAndNames(userGroups) //nolint:nilerr return &proto.ValidateSessionResponse{ - Valid: false, - UserId: user.Id, - UserEmail: user.Email, - DeniedReason: "not_in_group", + Valid: false, + UserId: user.Id, + UserEmail: user.Email, + DeniedReason: "not_in_group", + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, }, nil } @@ -1515,10 +1597,13 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val "email": user.Email, }).Debug("ValidateSession: access granted") + groupIDs, groupNames := pairGroupIDsAndNames(userGroups) return &proto.ValidateSessionResponse{ - Valid: true, - UserId: user.Id, - UserEmail: user.Email, + Valid: true, + UserId: user.Id, + UserEmail: user.Email, + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, }, nil } @@ -1551,3 +1636,154 @@ func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user * } func ptr[T any](v T) *T { return &v } + +// ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and +// checks the peer's group membership against the service's access groups. +// Peers without a user (machine agents, automation workloads) are first-class +// callers; authorisation runs off peer-group memberships rather than the +// optional owning user's auto-groups. On success a session JWT is minted so +// the proxy can install a cookie and skip subsequent management round-trips. +func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + domain := req.GetDomain() + tunnelIPStr := req.GetTunnelIp() + + if domain == "" || tunnelIPStr == "" { + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "missing domain or tunnel_ip", + }, nil + } + + tunnelIP := net.ParseIP(tunnelIPStr) + if tunnelIP == nil { + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "invalid_tunnel_ip", + }, nil + } + + service, err := s.getServiceByDomain(ctx, domain) + if err != nil { + log.WithFields(log.Fields{"domain": domain, "error": err.Error()}).Debug("ValidateTunnelPeer: service not found") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "service_not_found", + }, nil + } + + // Mirror ValidateSession: account-scoped (BYOP) proxy tokens may only + // validate and mint session cookies for their own account's domains. + if err := enforceAccountScope(ctx, service.AccountID); err != nil { + return nil, err + } + + peer, err := s.peersManager.GetPeerByTunnelIP(ctx, service.AccountID, tunnelIP) + if err != nil || peer == nil { + log.WithFields(log.Fields{"domain": domain, "tunnel_ip": tunnelIPStr}).Debug("ValidateTunnelPeer: peer not found") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "peer_not_found", + }, nil + } + + _, peerGroups, err := s.peersManager.GetPeerWithGroups(ctx, service.AccountID, peer.ID) + if err != nil { + log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: peer groups lookup failed") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + DeniedReason: "peer_not_found", + }, nil + } + + groupIDs, groupNames := pairGroupIDsAndNames(peerGroups) + + // Resolve the principal: when the peer is linked to a user, the human + // is the principal so multiple peers owned by the same user share a + // single identity. Unlinked peers (machine agents) are their own + // principal keyed on peer.ID. displayIdentity is what upstream gateways + // tag spend with โ€” user.Email when linked, peer.Name when not. + principalID := peer.ID + displayIdentity := peer.Name + if peer.UserID != "" { + if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil { + principalID = user.Id + if user.Email != "" { + displayIdentity = user.Email + } + } + } + + if err := checkPeerGroupAccess(service, groupIDs); err != nil { + log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied") + //nolint:nilerr + return &proto.ValidateTunnelPeerResponse{ + Valid: false, + UserId: principalID, + UserEmail: displayIdentity, + DeniedReason: "not_in_group", + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, + }, nil + } + + token, err := s.generateSessionToken(ctx, true, service, principalID, displayIdentity, proxyauth.MethodOIDC, groupIDs, groupNames) + if err != nil { + return nil, err + } + + log.WithFields(log.Fields{ + "domain": domain, + "tunnel_ip": tunnelIPStr, + "peer_id": peer.ID, + "principal_id": principalID, + }).Debug("ValidateTunnelPeer: access granted") + + return &proto.ValidateTunnelPeerResponse{ + Valid: true, + UserId: principalID, + UserEmail: displayIdentity, + SessionToken: token, + PeerGroupIds: groupIDs, + PeerGroupNames: groupNames, + }, nil +} + +// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required +// groups. Private services authorise against AccessGroups (empty list fails +// closed โ€” Validate() rejects that at save time but the RPC is the security +// boundary and must not trust upstream state). Bearer-auth services authorise +// against DistributionGroups when populated. Non-private non-bearer services +// are open. +func checkPeerGroupAccess(service *rpservice.Service, peerGroupIDs []string) error { + if service.Private { + if len(service.AccessGroups) == 0 { + return fmt.Errorf("private service has no access groups") + } + return matchAnyGroup(service.AccessGroups, peerGroupIDs) + } + if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled && len(service.Auth.BearerAuth.DistributionGroups) > 0 { + return matchAnyGroup(service.Auth.BearerAuth.DistributionGroups, peerGroupIDs) + } + return nil +} + +// matchAnyGroup returns nil when peerGroupIDs intersects allowedGroups, +// else a non-nil error. +func matchAnyGroup(allowedGroups, peerGroupIDs []string) error { + if len(allowedGroups) == 0 { + return fmt.Errorf("no allowed groups configured") + } + allowed := make(map[string]struct{}, len(allowedGroups)) + for _, g := range allowedGroups { + allowed[g] = struct{}{} + } + for _, g := range peerGroupIDs { + if _, ok := allowed[g]; ok { + return nil + } + } + return fmt.Errorf("peer not in allowed groups") +} diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 5980f8a30..76da7ddbc 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -129,6 +129,14 @@ func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.U return user, nil } +func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) { + user, err := m.GetUser(ctx, userID) + if err != nil { + return nil, nil, err + } + return user, nil, nil +} + func TestValidateUserGroupAccess(t *testing.T) { tests := []struct { name string @@ -420,3 +428,46 @@ func TestGetAccountProxyByDomain(t *testing.T) { }) } } + +func TestCheckPeerGroupAccess(t *testing.T) { + t.Run("private with empty AccessGroups denies", func(t *testing.T) { + svc := &service.Service{Private: true, AccessGroups: nil} + err := checkPeerGroupAccess(svc, []string{"grp-admins"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no access groups") + }) + + t.Run("private with peer in AccessGroups allows", func(t *testing.T) { + svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins", "grp-ops"}} + assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-other", "grp-ops"})) + }) + + t.Run("private with peer outside AccessGroups denies", func(t *testing.T) { + svc := &service.Service{Private: true, AccessGroups: []string{"grp-admins"}} + assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"})) + }) + + t.Run("bearer enabled with empty DistributionGroups allows", func(t *testing.T) { + svc := &service.Service{ + Auth: service.AuthConfig{BearerAuth: &service.BearerAuthConfig{Enabled: true}}, + } + assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-anyone"})) + }) + + t.Run("bearer enabled gates on DistributionGroups", func(t *testing.T) { + svc := &service.Service{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ + Enabled: true, + DistributionGroups: []string{"grp-allowed"}, + }, + }, + } + assert.NoError(t, checkPeerGroupAccess(svc, []string{"grp-allowed"})) + assert.Error(t, checkPeerGroupAccess(svc, []string{"grp-other"})) + }) + + t.Run("non-private non-bearer is open", func(t *testing.T) { + assert.NoError(t, checkPeerGroupAccess(&service.Service{}, nil)) + }) +} diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 774c5d1d3..1dc2dac28 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) { func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string { t.Helper() - token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour) + token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour) require.NoError(t, err) return token } @@ -125,6 +125,7 @@ func TestValidateSession_UserAllowed(t *testing.T) { assert.True(t, resp.Valid, "User should be allowed access") assert.Equal(t, "allowedUserId", resp.UserId) assert.Empty(t, resp.DeniedReason) + assert.Equal(t, []string{"allowedGroupId"}, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's group memberships") } func TestValidateSession_UserNotInAllowedGroup(t *testing.T) { @@ -145,6 +146,7 @@ func TestValidateSession_UserNotInAllowedGroup(t *testing.T) { assert.False(t, resp.Valid, "User not in group should be denied") assert.Equal(t, "not_in_group", resp.DeniedReason) assert.Equal(t, "nonGroupUserId", resp.UserId) + assert.Empty(t, resp.GetPeerGroupIds(), "PeerGroupIds must mirror the resolved user's actual (empty) memberships on denial") } func TestValidateSession_UserInDifferentAccount(t *testing.T) { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 8732cf89f..efe50c88f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -17,6 +17,7 @@ import ( rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ type DataSource interface { GetAllAccounts(ctx context.Context) []*types.Account GetStoreEngine() types.Engine GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) + GetProxyMetrics(ctx context.Context) (store.ProxyMetrics, error) } // ConnManager peer connection manager that holds state for current active connections @@ -223,6 +225,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties { servicesAuthPassword int servicesAuthPin int servicesAuthOIDC int + // Private-service signals โ€” track adoption of NetBird-only mode + // (services backed by an embedded proxy peer + access groups). + servicesPrivate int + servicesPrivateWithGroups int + servicesPrivateAccessGroupsSum int + servicesWithDirectUpstream int ) start := time.Now() metricsProperties := make(properties) @@ -380,9 +388,31 @@ func (w *Worker) generateProperties(ctx context.Context) properties { if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { servicesAuthOIDC++ } + + if service.Private { + servicesPrivate++ + if len(service.AccessGroups) > 0 { + servicesPrivateWithGroups++ + } + servicesPrivateAccessGroupsSum += len(service.AccessGroups) + } + + for _, target := range service.Targets { + if target.Options.DirectUpstream { + servicesWithDirectUpstream++ + break + } + } } } + // Proxy / BYOP cluster signals come from the proxies table aggregated + // across all accounts in a single store query; nil on FileStore. + proxyMetrics, err := w.dataSource.GetProxyMetrics(ctx) + if err != nil { + log.WithContext(ctx).Debugf("collect proxy metrics: %v", err) + } + minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions) metricsProperties["uptime"] = uptime metricsProperties["accounts"] = accounts @@ -430,6 +460,15 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["services_auth_password"] = servicesAuthPassword metricsProperties["services_auth_pin"] = servicesAuthPin metricsProperties["services_auth_oidc"] = servicesAuthOIDC + metricsProperties["services_private"] = servicesPrivate + metricsProperties["services_private_with_access_groups"] = servicesPrivateWithGroups + metricsProperties["services_private_access_groups_sum"] = servicesPrivateAccessGroupsSum + metricsProperties["services_with_direct_upstream"] = servicesWithDirectUpstream + metricsProperties["proxy_clusters"] = proxyMetrics.Clusters + metricsProperties["proxy_clusters_byop"] = proxyMetrics.ClustersBYOP + metricsProperties["proxy_clusters_private"] = proxyMetrics.ClustersPrivate + metricsProperties["proxies"] = proxyMetrics.Proxies + metricsProperties["proxies_connected"] = proxyMetrics.ProxiesConnected metricsProperties["custom_domains"] = customDomains metricsProperties["custom_domains_validated"] = customDomainsValidated diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 78f5c53be..ca9e10262 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -12,6 +12,7 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -123,7 +124,7 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { Enabled: true, Targets: []*rpservice.Target{ {TargetType: "peer"}, - {TargetType: "host"}, + {TargetType: "host", Options: rpservice.TargetOptions{DirectUpstream: true}}, }, Auth: rpservice.AuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true}, @@ -141,6 +142,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, Meta: rpservice.Meta{Status: string(rpservice.StatusPending)}, }, + { + ID: "svc3-private", + Enabled: true, + Private: true, + AccessGroups: []string{"grp-eng", "grp-ops"}, + Targets: []*rpservice.Target{ + {TargetType: "cluster", Options: rpservice.TargetOptions{DirectUpstream: true}}, + }, + Meta: rpservice.Meta{Status: string(rpservice.StatusActive)}, + }, }, }, { @@ -254,6 +265,18 @@ func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, e return 3, 2, nil } +// GetProxyMetrics returns canned proxy/cluster counts so the +// generateProperties test can assert the BYOP signals end-to-end. +func (mockDatasource) GetProxyMetrics(_ context.Context) (store.ProxyMetrics, error) { + return store.ProxyMetrics{ + Clusters: 3, + ClustersBYOP: 1, + ClustersPrivate: 1, + Proxies: 4, + ProxiesConnected: 2, + }, nil +} + // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties func TestGenerateProperties(t *testing.T) { ds := mockDatasource{} @@ -393,17 +416,17 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"]) } - if properties["services"] != 2 { - t.Errorf("expected 2 services, got %v", properties["services"]) + if properties["services"] != 3 { + t.Errorf("expected 3 services, got %v", properties["services"]) } - if properties["services_enabled"] != 1 { - t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"]) + if properties["services_enabled"] != 2 { + t.Errorf("expected 2 services_enabled, got %v", properties["services_enabled"]) } - if properties["services_targets"] != 3 { - t.Errorf("expected 3 services_targets, got %v", properties["services_targets"]) + if properties["services_targets"] != 4 { + t.Errorf("expected 4 services_targets, got %v", properties["services_targets"]) } - if properties["services_status_active"] != 1 { - t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"]) + if properties["services_status_active"] != 2 { + t.Errorf("expected 2 services_status_active, got %v", properties["services_status_active"]) } if properties["services_status_pending"] != 1 { t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"]) @@ -420,6 +443,9 @@ func TestGenerateProperties(t *testing.T) { if properties["services_target_type_domain"] != 1 { t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"]) } + if properties["services_target_type_cluster"] != 1 { + t.Errorf("expected 1 services_target_type_cluster, got %v", properties["services_target_type_cluster"]) + } if properties["services_auth_password"] != 1 { t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"]) } @@ -429,6 +455,33 @@ func TestGenerateProperties(t *testing.T) { if properties["services_auth_pin"] != 0 { t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"]) } + if properties["services_private"] != 1 { + t.Errorf("expected 1 services_private, got %v", properties["services_private"]) + } + if properties["services_private_with_access_groups"] != 1 { + t.Errorf("expected 1 services_private_with_access_groups, got %v", properties["services_private_with_access_groups"]) + } + if properties["services_private_access_groups_sum"] != 2 { + t.Errorf("expected 2 services_private_access_groups_sum, got %v", properties["services_private_access_groups_sum"]) + } + if properties["services_with_direct_upstream"] != 2 { + t.Errorf("expected 2 services_with_direct_upstream, got %v", properties["services_with_direct_upstream"]) + } + if properties["proxy_clusters"] != int64(3) { + t.Errorf("expected 3 proxy_clusters, got %v", properties["proxy_clusters"]) + } + if properties["proxy_clusters_byop"] != int64(1) { + t.Errorf("expected 1 proxy_clusters_byop, got %v", properties["proxy_clusters_byop"]) + } + if properties["proxy_clusters_private"] != int64(1) { + t.Errorf("expected 1 proxy_clusters_private, got %v", properties["proxy_clusters_private"]) + } + if properties["proxies"] != int64(4) { + t.Errorf("expected 4 proxies, got %v", properties["proxies"]) + } + if properties["proxies_connected"] != int64(2) { + t.Errorf("expected 2 proxies_connected, got %v", properties["proxies_connected"]) + } if properties["custom_domains"] != int64(3) { t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"]) } diff --git a/management/server/peer.go b/management/server/peer.go index 34b681f51..37cacee41 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -125,6 +125,18 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } } + // An embedded proxy peer flipping to connected is the trigger for + // SynthesizePrivateServiceZones to emit DNS A records pointing at its + // tunnel IP. Without an account-wide netmap recompute, user peers keep + // the stale synth (or no synth at all on first connect) until some + // other change pokes the controller. Fire OnPeersUpdated so the + // buffered recompute fans the new state out to every peer. + if peer.ProxyMeta.Embedded { + if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil { + log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err) + } + } + return nil } @@ -160,6 +172,17 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP return nil } am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied) + + // Symmetric with MarkPeerConnected: when an embedded proxy peer goes + // offline, drive an account-wide netmap recompute so the synthesized + // DNS records that pointed at it are pulled. Without this the records + // linger client-side at TTL until something else triggers a refresh. + if peer.ProxyMeta.Embedded { + if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil { + log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err) + } + } + return nil } diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index 81185b020..bcf563cd0 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -274,3 +274,9 @@ func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) { func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) { return 0, 0, nil } + +// GetProxyMetrics is a no-op for FileStore โ€” proxy/cluster state isn't +// persisted in the JSON file format. +func (s *FileStore) GetProxyMetrics(_ context.Context) (ProxyMetrics, error) { + return ProxyMetrics{}, nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 279c0e21f..d8c27fb5c 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1090,6 +1090,38 @@ func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, er return total, validated, nil } +// GetProxyMetrics aggregates per-cluster + per-proxy counts for the +// self-hosted telemetry payload. Single round-trip via conditional +// aggregations so a large proxies table doesn't fan out into multiple +// queries. +func (s *SqlStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) { + var m ProxyMetrics + activeCutoff := time.Now().Add(-proxyActiveThreshold) + + // COUNT(DISTINCT ... CASE WHEN ...) is portable across sqlite/postgres + // (MySQL too) and keeps the round-trip to one. proxy.StatusConnected + // is the same string the cluster-capability queries use; the active + // window matches the cluster-capability semantics (only proxies + // heartbeating within ~2 * heartbeat interval count as connected). + row := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Select( + "COUNT(DISTINCT cluster_address) AS clusters, "+ + "COUNT(DISTINCT CASE WHEN account_id IS NOT NULL THEN cluster_address END) AS clusters_byop, "+ + "COUNT(DISTINCT CASE WHEN private = ? THEN cluster_address END) AS clusters_private, "+ + "COUNT(*) AS proxies, "+ + "COUNT(CASE WHEN status = ? AND last_seen > ? THEN 1 END) AS proxies_connected", + true, + proxy.StatusConnected, + activeCutoff, + ). + Row() + if err := row.Scan(&m.Clusters, &m.ClustersBYOP, &m.ClustersPrivate, &m.Proxies, &m.ProxiesConnected); err != nil { + return ProxyMetrics{}, fmt.Errorf("scan proxy metrics: %w", err) + } + return m, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { var accounts []types.Account result := s.db.Find(&accounts) @@ -2178,7 +2210,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, pass_host_header, rewrite_redirects, session_private_key, session_public_key, - mode, listen_port, port_auto_assigned, source, source_peer, terminated + mode, listen_port, port_auto_assigned, source, source_peer, terminated, + private, access_groups FROM services WHERE account_id = $1` const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol, @@ -2193,10 +2226,11 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) { var s rpservice.Service var auth []byte + var accessGroups []byte var createdAt, certIssuedAt sql.NullTime var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString var mode, source, sourcePeer sql.NullString - var terminated, portAutoAssigned sql.NullBool + var terminated, portAutoAssigned, private sql.NullBool var listenPort sql.NullInt64 err := row.Scan( &s.ID, @@ -2219,6 +2253,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv &source, &sourcePeer, &terminated, + &private, + &accessGroups, ) if err != nil { return nil, err @@ -2230,6 +2266,16 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv } } + if len(accessGroups) > 0 { + if err := json.Unmarshal(accessGroups, &s.AccessGroups); err != nil { + return nil, fmt.Errorf("unmarshal access_groups: %w", err) + } + } + + if private.Valid { + s.Private = private.Bool + } + s.Meta = rpservice.Meta{} if createdAt.Valid { s.Meta.CreatedAt = createdAt.Time @@ -5826,6 +5872,7 @@ var validCapabilityColumns = map[string]struct{}{ "supports_custom_ports": {}, "require_subdomain": {}, "supports_crowdsec": {}, + "private": {}, } // GetClusterSupportsCustomPorts returns whether any active proxy in the cluster @@ -5840,6 +5887,12 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s return s.getClusterCapability(ctx, clusterAddr, "require_subdomain") } +// GetClusterSupportsPrivate reports whether any active proxy in the cluster +// has the private capability (nil = unreported). +func (s *SqlStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "private") +} + // GetClusterSupportsCrowdSec returns whether all active proxies in the cluster // have CrowdSec configured. Returns nil when no proxy reported the capability. // Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec diff --git a/management/server/store/sql_store_service_test.go b/management/server/store/sql_store_service_test.go new file mode 100644 index 000000000..0978440c6 --- /dev/null +++ b/management/server/store/sql_store_service_test.go @@ -0,0 +1,46 @@ +package store + +import ( + "context" + "os" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" +) + +func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) { + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") + } + + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + ctx := context.Background() + account := newAccountWithId(ctx, "account_private_svc", "testuser", "") + require.NoError(t, store.SaveAccount(ctx, account)) + + svc := &rpservice.Service{ + ID: "svc-private", + AccountID: account.Id, + Name: "private-svc", + Domain: "private.example", + ProxyCluster: "cluster.example", + Enabled: true, + Mode: rpservice.ModeHTTP, + Private: true, + AccessGroups: []string{"grp-admins", "grp-ops"}, + } + require.NoError(t, store.CreateService(ctx, svc)) + + loaded, err := store.GetAccount(ctx, account.Id) + require.NoError(t, err) + require.Len(t, loaded.Services, 1) + + got := loaded.Services[0] + assert.True(t, got.Private) + assert.Equal(t, []string{"grp-admins", "grp-ops"}, got.AccessGroups) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 39b1c0ed3..746207f27 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -312,6 +312,7 @@ type Store interface { GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool + GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) @@ -320,9 +321,38 @@ type Store interface { GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) + // GetProxyMetrics returns aggregated proxy / cluster counts for the + // self-hosted metrics worker. Self-hosted only โ€” file-based stores + // return a zero-valued struct. + GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) + GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) } +// ProxyMetrics aggregates self-hosted proxy + cluster usage signals +// surfaced to the telemetry payload. Each field is best-effort: when a +// store cannot answer (e.g. FileStore) all fields are zero. +type ProxyMetrics struct { + // Clusters counts distinct cluster_address values across the proxies + // table โ€” every cluster the management server has heard from, online or not. + Clusters int64 + // ClustersBYOP counts distinct cluster_address values that are owned + // by an account (account_id IS NOT NULL). These are bring-your-own-proxy + // installations as opposed to NetBird-operated shared clusters. + ClustersBYOP int64 + // ClustersPrivate counts distinct cluster_address values where at + // least one proxy reported the private capability (embedded + // `netbird proxy` running inside a client). + ClustersPrivate int64 + // Proxies is the total number of proxy rows currently persisted. + Proxies int64 + // ProxiesConnected is the subset of proxies whose status is + // "connected" AND last_seen falls within the active heartbeat window + // (~2 * heartbeat interval). Proxies the controller hasn't pruned + // yet but that are visibly stale don't count. + ProxiesConnected int64 +} + const ( postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN" postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index c7e86c2db..dfd5af78d 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1461,6 +1461,20 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) } +// GetClusterSupportsPrivate mocks base method. +func (m *MockStore) GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsPrivate", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsPrivate indicates an expected call of GetClusterSupportsPrivate. +func (mr *MockStoreMockRecorder) GetClusterSupportsPrivate(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsPrivate", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsPrivate), ctx, clusterAddr) +} + // GetCustomDomain mocks base method. func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) { m.ctrl.T.Helper() @@ -2076,6 +2090,21 @@ func (mr *MockStoreMockRecorder) GetProxyClusters(ctx, accountID interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyClusters", reflect.TypeOf((*MockStore)(nil).GetProxyClusters), ctx, accountID) } +// GetProxyMetrics mocks base method. +func (m *MockStore) GetProxyMetrics(ctx context.Context) (ProxyMetrics, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyMetrics", ctx) + ret0, _ := ret[0].(ProxyMetrics) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyMetrics indicates an expected call of GetProxyMetrics. +func (mr *MockStoreMockRecorder) GetProxyMetrics(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyMetrics", reflect.TypeOf((*MockStore)(nil).GetProxyMetrics), ctx) +} + // GetResourceGroups mocks base method. func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { m.ctrl.T.Helper() diff --git a/management/server/types/account.go b/management/server/types/account.go index 870333a60..dc0c5a685 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -32,7 +32,9 @@ import ( ) const ( - defaultTTL = 300 + defaultTTL = 300 + // privateServiceDNSRecordTTL is short so proxy-peer changes propagate quickly to clients. + privateServiceDNSRecordTTL = 5 DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute @@ -254,6 +256,117 @@ func getUniqueHostLabel(name string, peerLabels LookupMap) string { return "" } +// SynthesizePrivateServiceZones returns in-memory CustomZones with A records pointing each enabled private service the peer can reach at the cluster's proxy-peer IPs. One zone per cluster (multiple services share); records gated by AccessGroups. +func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZone { + peer, ok := a.Peers[peerID] + if !ok || peer == nil { + return nil + } + if len(a.Services) == 0 { + return nil + } + + proxyPeersByCluster := a.GetProxyPeers() + if len(proxyPeersByCluster) == 0 { + return nil + } + + peerGroups := a.GetPeerGroups(peerID) + zonesByCluster := map[string]*nbdns.CustomZone{} + + for _, svc := range a.Services { + if svc == nil || !svc.Enabled || !svc.Private { + continue + } + if len(svc.AccessGroups) == 0 { + continue + } + if !peerInDistributionGroups(peerGroups, svc.AccessGroups) { + continue + } + proxyPeers := proxyPeersByCluster[svc.ProxyCluster] + if len(proxyPeers) == 0 { + continue + } + + zone, exists := zonesByCluster[svc.ProxyCluster] + if !exists { + // NonAuthoritative makes this a match-only zone: queries for + // names without an explicit record fall through to the + // upstream resolver instead of returning NXDOMAIN. Without + // it, adding a single private service would black-hole every + // other name under the cluster apex. + zone = &nbdns.CustomZone{ + Domain: dns.Fqdn(svc.ProxyCluster), + Records: []nbdns.SimpleRecord{}, + NonAuthoritative: true, + } + zonesByCluster[svc.ProxyCluster] = zone + } + + emitted := 0 + skippedDisconnected := 0 + for _, p := range proxyPeers { + if p == nil || !p.IP.IsValid() { + continue + } + // Only emit a record when the proxy peer is actually + // connected. A disconnected proxy peer's tunnel IP won't + // answer; pointing DNS at it would produce a black hole + // for as long as the record is cached client-side. + if p.Status == nil || !p.Status.Connected { + skippedDisconnected++ + continue + } + zone.Records = append(zone.Records, nbdns.SimpleRecord{ + Name: dns.Fqdn(svc.Domain), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: privateServiceDNSRecordTTL, + RData: p.IP.String(), + }) + emitted++ + } + // Disagreement with the firewall path is the typical + // "domain doesn't reach client but firewall rules do" + // symptom: the synth service is otherwise fine, only the + // proxy peer's persisted Connected flag is wrong (most + // likely the connection reaper marked it disconnected even + // though the gRPC stream is alive). + if emitted == 0 && skippedDisconnected > 0 { + log.Debugf("private-zone synth: svc %s domain=%s cluster=%s emitted_zero proxy_peers=%d all_disconnected=%d (firewall would still fire)", + svc.ID, svc.Domain, svc.ProxyCluster, len(proxyPeers), skippedDisconnected) + } + } + + out := make([]nbdns.CustomZone, 0, len(zonesByCluster)) + for _, zone := range zonesByCluster { + if len(zone.Records) == 0 { + continue + } + out = append(out, *zone) + } + if len(out) == 0 && len(a.Services) > 0 { + // Targeted diagnostic for the "firewall yes, DNS no" divergence โ€” + // fires only when services exist but synth returns zero zones, + // so accounts without private services produce no noise. + log.Debugf("private-zone synth: peer %s account %s returned 0 zones from %d candidate service(s)", + peerID, a.Id, len(a.Services)) + } + return out +} + +// peerInDistributionGroups reports whether any of the peer's groups +// matches the service's bearer-auth distribution_groups. +func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool { + for _, gid := range distributionGroups { + if _, ok := peerGroups[gid]; ok { + return true + } + } + return false +} + func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { var merr *multierror.Error @@ -1498,6 +1611,53 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *servi a.injectTargetProxyPolicies(ctx, service, target, proxyPeers) } + a.injectPrivateServicePolicies(service, proxyPeers) +} + +// injectPrivateServicePolicies synthesises an in-memory ACL: AccessGroups โ†’ cluster proxy peers on TCP 80/443. +func (a *Account) injectPrivateServicePolicies(svc *service.Service, proxyPeers []*nbpeer.Peer) { + if !svc.Private { + return + } + if len(svc.AccessGroups) == 0 { + return + } + if len(proxyPeers) == 0 { + return + } + for _, proxyPeer := range proxyPeers { + a.Policies = append(a.Policies, a.createPrivateServicePolicy(svc, proxyPeer)) + } +} + +func (a *Account) createPrivateServicePolicy(svc *service.Service, proxyPeer *nbpeer.Peer) *Policy { + policyID := fmt.Sprintf("private-access-%s-%s", svc.ID, proxyPeer.ID) + sources := append([]string(nil), svc.AccessGroups...) + return &Policy{ + ID: policyID, + Name: fmt.Sprintf("Private Access to %s", svc.Name), + Enabled: true, + Rules: []*PolicyRule{ + { + ID: policyID, + PolicyID: policyID, + Name: fmt.Sprintf("Allow access groups to reach %s", svc.Name), + Enabled: true, + Sources: sources, + DestinationResource: Resource{ + ID: proxyPeer.ID, + Type: ResourceTypePeer, + }, + Bidirectional: false, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 443, End: 443}, + }, + }, + }, + } } func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go index 2b4f7e051..a42028351 100644 --- a/management/server/types/account_components.go +++ b/management/server/types/account_components.go @@ -119,6 +119,7 @@ func (a *Account) GetPeerNetworkMapComponents( peerGroups := a.GetPeerGroups(peerID) components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups) + components.AccountZones = append(components.AccountZones, a.SynthesizePrivateServiceZones(peerID)...) for _, nsGroup := range a.NameServerGroups { if nsGroup.Enabled { diff --git a/management/server/types/account_private_netmap_test.go b/management/server/types/account_private_netmap_test.go new file mode 100644 index 000000000..dc097ce26 --- /dev/null +++ b/management/server/types/account_private_netmap_test.go @@ -0,0 +1,85 @@ +package types + +import ( + "context" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestPrivateService_NetworkMap_UserPeer_AndProxyPeer(t *testing.T) { + account := privateZoneTestAccount(t) + account.Peers["user-peer"].Meta.WtVersion = "0.50.0" + account.Peers["proxy-peer"].Meta.WtVersion = "0.50.0" + + ctx := context.Background() + account.InjectProxyPolicies(ctx) + + validated := map[string]struct{}{ + "user-peer": {}, + "proxy-peer": {}, + } + + t.Run("user-peer update", func(t *testing.T) { + nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil) + require.NotNil(t, nm) + + zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io") + require.True(t, ok) + require.Len(t, zone.Records, 1) + assert.Equal(t, "myapp.eu.proxy.netbird.io.", zone.Records[0].Name) + assert.Equal(t, int(dns.TypeA), zone.Records[0].Type) + assert.Equal(t, "100.64.0.99", zone.Records[0].RData) + + assert.Contains(t, netmapPeerIDs(nm.Peers), "proxy-peer") + assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.99", FirewallRuleDirectionOUT) + }) + + t.Run("proxy-peer update", func(t *testing.T) { + nm := account.GetPeerNetworkMapFromComponents(ctx, "proxy-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil) + require.NotNil(t, nm) + + assert.Contains(t, netmapPeerIDs(nm.Peers), "user-peer") + assertPrivateServiceFirewallRules(t, nm.FirewallRules, "100.64.0.10", FirewallRuleDirectionIN) + }) +} + +func netmapPeerIDs(peers []*nbpeer.Peer) []string { + ids := make([]string, 0, len(peers)) + for _, p := range peers { + if p == nil { + continue + } + ids = append(ids, p.ID) + } + return ids +} + +func assertPrivateServiceFirewallRules(t *testing.T, rules []*FirewallRule, peerIP string, direction int) { + t.Helper() + wantPorts := map[uint16]bool{80: false, 443: false} + for _, r := range rules { + if r == nil || r.PeerIP != peerIP || r.Direction != direction { + continue + } + if r.Protocol != string(PolicyRuleProtocolTCP) || r.Action != string(PolicyTrafficActionAccept) { + continue + } + switch { + case r.PortRange.Start == r.PortRange.End && r.PortRange.Start != 0: + wantPorts[r.PortRange.Start] = true + case r.Port == "80": + wantPorts[80] = true + case r.Port == "443": + wantPorts[443] = true + } + } + for port, found := range wantPorts { + assert.Truef(t, found, "missing TCP accept rule on port %d for peer %s direction %d", port, peerIP, direction) + } +} diff --git a/management/server/types/account_private_zones_test.go b/management/server/types/account_private_zones_test.go new file mode 100644 index 000000000..1d4f720b7 --- /dev/null +++ b/management/server/types/account_private_zones_test.go @@ -0,0 +1,256 @@ +package types + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func privateZoneTestAccount(t *testing.T) *Account { + t.Helper() + return &Account{ + Id: "acct-1", + Settings: &Settings{}, + Network: &Network{ + Identifier: "net-1", + Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + }, + Peers: map[string]*nbpeer.Peer{ + "user-peer": { + ID: "user-peer", + AccountID: "acct-1", + Key: "user-peer-key", + IP: netip.MustParseAddr("100.64.0.10"), + Status: &nbpeer.PeerStatus{Connected: true}, + }, + "proxy-peer": { + ID: "proxy-peer", + AccountID: "acct-1", + Key: "proxy-peer-key", + IP: netip.MustParseAddr("100.64.0.99"), + Status: &nbpeer.PeerStatus{Connected: true}, + ProxyMeta: nbpeer.ProxyMeta{ + Embedded: true, + Cluster: "eu.proxy.netbird.io", + }, + }, + }, + Groups: map[string]*Group{ + "grp-admins": { + ID: "grp-admins", + Name: "admins", + Peers: []string{"user-peer"}, + }, + }, + Services: []*service.Service{ + { + ID: "svc-1", + AccountID: "acct-1", + Name: "myapp", + Domain: "myapp.eu.proxy.netbird.io", + ProxyCluster: "eu.proxy.netbird.io", + Enabled: true, + Private: true, + Mode: service.ModeHTTP, + AccessGroups: []string{"grp-admins"}, + }, + }, + } +} + +func TestSynthesizePrivateServiceZones_PeerInGroup_GetsRecord(t *testing.T) { + account := privateZoneTestAccount(t) + + zones := account.SynthesizePrivateServiceZones("user-peer") + require.Len(t, zones, 1, "one cluster should produce one zone") + zone := zones[0] + assert.Equal(t, "eu.proxy.netbird.io.", zone.Domain, "zone apex must be the cluster FQDN") + assert.True(t, zone.NonAuthoritative, "synth zone must be match-only so unrelated sibling names fall through to the upstream resolver") + require.Len(t, zone.Records, 1, "one private service yields one A record") + rec := zone.Records[0] + assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN") + assert.Equal(t, int(dns.TypeA), rec.Type, "record type must be A") + assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP") + assert.Equal(t, privateServiceDNSRecordTTL, rec.TTL, "TTL must match the synth-records constant") + assert.Equal(t, nbdns.DefaultClass, rec.Class, "record class must be the package default") +} + +func TestSynthesizePrivateServiceZones_PeerNotInGroup_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + account.Groups["grp-admins"].Peers = nil + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "peer outside distribution_groups must not see private-service records") +} + +func TestSynthesizePrivateServiceZones_NotPrivate_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + account.Services[0].Private = false + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "non-private service must not produce DNS records") +} + +func TestSynthesizePrivateServiceZones_NoAccessGroups_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + account.Services[0].AccessGroups = nil + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "private service without bearer auth must not produce DNS records") +} + +func TestSynthesizePrivateServiceZones_NoProxyPeers_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + delete(account.Peers, "proxy-peer") + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "no embedded proxy peer in cluster means no record to emit") +} + +func TestSynthesizePrivateServiceZones_DisabledService_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + account.Services[0].Enabled = false + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "disabled service must not produce DNS records") +} + +func TestSynthesizePrivateServiceZones_DisconnectedProxyPeer_NoRecord(t *testing.T) { + account := privateZoneTestAccount(t) + account.Peers["proxy-peer"].Status = &nbpeer.PeerStatus{Connected: false} + + zones := account.SynthesizePrivateServiceZones("user-peer") + assert.Empty(t, zones, "disconnected proxy peer must not produce a DNS record (would be a black hole)") +} + +func TestSynthesizePrivateServiceZones_PartiallyDisconnectedProxyPeers_OnlyConnectedSurface(t *testing.T) { + account := privateZoneTestAccount(t) + account.Peers["proxy-peer-2"] = &nbpeer.Peer{ + ID: "proxy-peer-2", + AccountID: "acct-1", + Key: "proxy-peer-2-key", + IP: netip.MustParseAddr("100.64.0.100"), + Status: &nbpeer.PeerStatus{Connected: false}, + ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"}, + } + + zones := account.SynthesizePrivateServiceZones("user-peer") + require.Len(t, zones, 1) + require.Len(t, zones[0].Records, 1, "only the connected proxy peer must surface") + assert.Equal(t, "100.64.0.99", zones[0].Records[0].RData) +} + +func TestSynthesizePrivateServiceZones_MultipleProxyPeers_RoundRobin(t *testing.T) { + account := privateZoneTestAccount(t) + account.Peers["proxy-peer-2"] = &nbpeer.Peer{ + ID: "proxy-peer-2", + AccountID: "acct-1", + Key: "proxy-peer-2-key", + IP: netip.MustParseAddr("100.64.0.100"), + Status: &nbpeer.PeerStatus{Connected: true}, + ProxyMeta: nbpeer.ProxyMeta{Embedded: true, Cluster: "eu.proxy.netbird.io"}, + } + + zones := account.SynthesizePrivateServiceZones("user-peer") + require.Len(t, zones, 1, "still one cluster yields one zone") + require.Len(t, zones[0].Records, 2, "two proxy peers must produce two A records on the same name") + rdata := []string{zones[0].Records[0].RData, zones[0].Records[1].RData} + assert.ElementsMatch(t, []string{"100.64.0.99", "100.64.0.100"}, rdata, "both proxy peer IPs must surface") +} + +// findCustomZone returns the CustomZone whose Domain equals the FQDN +// of want, or a zero value when not found. Tests use it to assert +// that the synth zone reaches dnsUpdate.CustomZones end-to-end. +func findCustomZone(zones []nbdns.CustomZone, want string) (nbdns.CustomZone, bool) { + wantFqdn := dns.Fqdn(want) + for _, z := range zones { + if z.Domain == wantFqdn { + return z, true + } + } + return nbdns.CustomZone{}, false +} + +// TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone +// covers the components-based builder path. The components builder +// appends SynthesizePrivateServiceZones to AccountZones; the +// CalculateNetworkMapFromComponents step then merges AccountZones +// into dnsUpdate.CustomZones. +func TestPrivateZone_GetPeerNetworkMapFromComponents_ShipsSynthZone(t *testing.T) { + account := privateZoneTestAccount(t) + ctx := context.Background() + validated := map[string]struct{}{ + "user-peer": {}, + "proxy-peer": {}, + } + + nm := account.GetPeerNetworkMapFromComponents(ctx, "user-peer", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil) + require.NotNil(t, nm, "network map must be produced for an in-account peer") + + zone, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io") + require.True(t, ok, "shipped CustomZones must include the synth zone for the cluster") + require.Len(t, zone.Records, 1, "exactly one record per private service per connected proxy peer") + rec := zone.Records[0] + assert.Equal(t, "myapp.eu.proxy.netbird.io.", rec.Name, "record name is the service FQDN") + assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP") +} + +// TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone +// confirms the negative case the user encountered: a peer whose +// groups don't overlap the policy's distribution_groups gets a +// network map with no synth zone (and the wildcard / peer zones still +// flow through). This is the test mirror of the runtime confusion +// where the user looked at a non-distribution-group peer and assumed +// the synth path was broken. +func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testing.T) { + account := privateZoneTestAccount(t) + account.Peers["outsider"] = &nbpeer.Peer{ + ID: "outsider", + AccountID: "acct-1", + Key: "outsider-key", + IP: netip.MustParseAddr("100.64.0.20"), + Status: &nbpeer.PeerStatus{Connected: true}, + } + ctx := context.Background() + validated := map[string]struct{}{ + "user-peer": {}, + "proxy-peer": {}, + "outsider": {}, + } + + nm := account.GetPeerNetworkMapFromComponents(ctx, "outsider", nbdns.CustomZone{}, nil, validated, nil, nil, nil, nil) + require.NotNil(t, nm) + + _, ok := findCustomZone(nm.DNSConfig.CustomZones, "eu.proxy.netbird.io") + assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone") +} + +func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) { + account := privateZoneTestAccount(t) + account.Services = append(account.Services, &service.Service{ + ID: "svc-2", + AccountID: "acct-1", + Name: "anotherapp", + Domain: "anotherapp.eu.proxy.netbird.io", + ProxyCluster: "eu.proxy.netbird.io", + Enabled: true, + Private: true, + Mode: service.ModeHTTP, + AccessGroups: []string{"grp-admins"}, + }) + + zones := account.SynthesizePrivateServiceZones("user-peer") + require.Len(t, zones, 1, "two services on the same cluster must collapse into one zone") + require.Len(t, zones[0].Records, 2, "two services yield two A records") + names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name} + assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface") +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index a1a616882..b55b41638 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -3,6 +3,7 @@ package types import ( "context" "fmt" + "net" "net/netip" "testing" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -82,9 +84,9 @@ func setupTestAccount() *Account { }, Groups: map[string]*Group{ "groupAll": { - ID: "groupAll", - Name: "All", - Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, + ID: "groupAll", + Name: "All", + Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"}, Issued: GroupIssuedAPI, }, "group1": { @@ -1583,3 +1585,203 @@ func Test_filterPeerAppliedZones(t *testing.T) { }) } } + +func TestInjectPrivateServicePolicies_ProxyPeerGetsInboundRule(t *testing.T) { + ctx := context.Background() + + userPeerIP := netip.MustParseAddr("100.64.0.10") + proxyPeerIP := netip.MustParseAddr("100.64.0.99") + + account := &Account{ + Id: "acct-1", + Network: &Network{ + Identifier: "net-1", + Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + }, + Peers: map[string]*nbpeer.Peer{ + "user-peer": { + ID: "user-peer", + AccountID: "acct-1", + Key: "user-peer-key", + IP: userPeerIP, + }, + "proxy-peer": { + ID: "proxy-peer", + AccountID: "acct-1", + Key: "proxy-peer-key", + IP: proxyPeerIP, + ProxyMeta: nbpeer.ProxyMeta{ + Embedded: true, + Cluster: "eu.proxy.netbird.io", + }, + }, + }, + Groups: map[string]*Group{ + "grp-admins": { + ID: "grp-admins", + Name: "admins", + Peers: []string{"user-peer"}, + }, + }, + Services: []*service.Service{ + { + ID: "svc-1", + AccountID: "acct-1", + Name: "myapp", + Domain: "myapp.eu.proxy.netbird.io", + ProxyCluster: "eu.proxy.netbird.io", + Enabled: true, + Private: true, + Mode: service.ModeHTTP, + AccessGroups: []string{"grp-admins"}, + Targets: []*service.Target{ + { + TargetId: "eu.proxy.netbird.io", + TargetType: service.TargetTypeCluster, + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Enabled: true, + }, + }, + }, + }, + } + + account.InjectProxyPolicies(ctx) + + var found *Policy + for _, p := range account.Policies { + if p != nil && p.ID == "private-access-svc-1-proxy-peer" { + found = p + break + } + } + require.NotNil(t, found, "expected synthesised private-access policy in account.Policies") + require.Len(t, found.Rules, 1, "policy should have exactly one rule") + rule := found.Rules[0] + assert.Equal(t, []string{"grp-admins"}, rule.Sources, "sources should be group IDs verbatim") + assert.Equal(t, "proxy-peer", rule.DestinationResource.ID, "destination resource should be the proxy peer ID") + assert.Equal(t, ResourceTypePeer, rule.DestinationResource.Type, "destination resource type should be peer") + + validatedPeersMap := map[string]struct{}{ + "user-peer": {}, + "proxy-peer": {}, + } + + proxyPeer := account.Peers["proxy-peer"] + aclPeers, firewallRules, _, _ := account.GetPeerConnectionResources(ctx, proxyPeer, validatedPeersMap, nil) + + var sawUserAsAclPeer bool + for _, p := range aclPeers { + if p.ID == "user-peer" { + sawUserAsAclPeer = true + break + } + } + assert.True(t, sawUserAsAclPeer, "proxy peer should see the user peer as an ACL peer") + + var inboundRules []*FirewallRule + for _, r := range firewallRules { + if r.Direction == FirewallRuleDirectionIN && r.PeerIP == userPeerIP.String() { + inboundRules = append(inboundRules, r) + } + } + assert.NotEmpty(t, inboundRules, "proxy peer should have inbound firewall rules from the user peer") +} + +func TestInjectPrivateServicePolicies_NotPrivate_NoPolicy(t *testing.T) { + ctx := context.Background() + account := privateServiceTestAccount(t) + account.Services[0].Private = false + + account.InjectProxyPolicies(ctx) + assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "non-private service must not synthesise an access policy") +} + +func TestInjectPrivateServicePolicies_EmptyAccessGroups_NoPolicy(t *testing.T) { + ctx := context.Background() + account := privateServiceTestAccount(t) + account.Services[0].AccessGroups = nil + + account.InjectProxyPolicies(ctx) + assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "private service with no access groups must not synthesise a policy") +} + +func TestInjectPrivateServicePolicies_NoProxyPeers_NoPolicy(t *testing.T) { + ctx := context.Background() + account := privateServiceTestAccount(t) + delete(account.Peers, "proxy-peer") + + account.InjectProxyPolicies(ctx) + assert.False(t, hasPrivateAccessPolicy(account, "svc-1"), "policy must not synthesise when the cluster has no proxy peers") +} + +func privateServiceTestAccount(t *testing.T) *Account { + t.Helper() + return &Account{ + Id: "acct-1", + Network: &Network{ + Identifier: "net-1", + Net: net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + }, + Peers: map[string]*nbpeer.Peer{ + "user-peer": { + ID: "user-peer", + AccountID: "acct-1", + Key: "user-peer-key", + IP: netip.MustParseAddr("100.64.0.10"), + }, + "proxy-peer": { + ID: "proxy-peer", + AccountID: "acct-1", + Key: "proxy-peer-key", + IP: netip.MustParseAddr("100.64.0.99"), + ProxyMeta: nbpeer.ProxyMeta{ + Embedded: true, + Cluster: "eu.proxy.netbird.io", + }, + }, + }, + Groups: map[string]*Group{ + "grp-admins": { + ID: "grp-admins", + Name: "admins", + Peers: []string{"user-peer"}, + }, + }, + Services: []*service.Service{ + { + ID: "svc-1", + AccountID: "acct-1", + Name: "myapp", + Domain: "myapp.eu.proxy.netbird.io", + ProxyCluster: "eu.proxy.netbird.io", + Enabled: true, + Private: true, + Mode: service.ModeHTTP, + AccessGroups: []string{"grp-admins"}, + Targets: []*service.Target{ + { + TargetId: "eu.proxy.netbird.io", + TargetType: service.TargetTypeCluster, + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Enabled: true, + }, + }, + }, + }, + } +} + +func hasPrivateAccessPolicy(account *Account, serviceID string) bool { + prefix := "private-access-" + serviceID + "-" + for _, p := range account.Policies { + if p != nil && len(p.ID) > len(prefix) && p.ID[:len(prefix)] == prefix { + return true + } + } + return false +} diff --git a/management/server/users/manager.go b/management/server/users/manager.go index e07f28706..1a05b1a7c 100644 --- a/management/server/users/manager.go +++ b/management/server/users/manager.go @@ -10,6 +10,7 @@ import ( type Manager interface { GetUser(ctx context.Context, userID string) (*types.User, error) + GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) } type managerImpl struct { @@ -29,6 +30,31 @@ func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) } +// GetUserWithGroups returns the user and the *types.Group records for the user's AutoGroups, in the same order as +// AutoGroups. Group ids that don't resolve to a stored group are skipped from the returned slice (the parallel id list is +// derivable from the returned User). Wraps two store calls today; can be optimised to a single JOIN later if needed. +// Any store error returns (nil, nil, err) so callers never receive a valid user alongside a non-nil error. +func (m *managerImpl) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) { + user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) + if err != nil { + return nil, nil, err + } + if len(user.AutoGroups) == 0 { + return user, nil, nil + } + groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups) + if err != nil { + return nil, nil, err + } + groups := make([]*types.Group, 0, len(user.AutoGroups)) + for _, id := range user.AutoGroups { + if g, ok := groupsMap[id]; ok && g != nil { + groups = append(groups, g) + } + } + return user, groups, nil +} + func NewManagerMock() Manager { return &managerMock{} } @@ -47,3 +73,11 @@ func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User, return nil, errors.New("user not found") } } + +func (m *managerMock) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) { + user, err := m.GetUser(ctx, userID) + if err != nil { + return nil, nil, err + } + return user, nil, nil +} diff --git a/proxy/auth/auth.go b/proxy/auth/auth.go index ca9c260b7..78f0097d5 100644 --- a/proxy/auth/auth.go +++ b/proxy/auth/auth.go @@ -45,10 +45,14 @@ func ResolveProto(forwardedProto string, conn *tls.ConnectionState) string { } } -// ValidateSessionJWT validates a session JWT and returns the user ID and method. -func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, method string, err error) { +// ValidateSessionJWT validates a session JWT and returns the user ID, the +// user's email (when carried), the authentication method, any embedded +// group memberships, and the parallel group display names. email, +// groups, and groupNames may be empty for tokens minted before those +// claims were introduced. groupNames pairs positionally with groups. +func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) (userID, email, method string, groups, groupNames []string, err error) { if publicKey == nil { - return "", "", fmt.Errorf("no public key configured for domain") + return "", "", "", nil, nil, fmt.Errorf("no public key configured for domain") } token, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) { @@ -58,20 +62,46 @@ func ValidateSessionJWT(tokenString, domain string, publicKey ed25519.PublicKey) return publicKey, nil }, jwt.WithAudience(domain), jwt.WithIssuer(SessionJWTIssuer)) if err != nil { - return "", "", fmt.Errorf("parse token: %w", err) + return "", "", "", nil, nil, fmt.Errorf("parse token: %w", err) } claims, ok := token.Claims.(jwt.MapClaims) if !ok || !token.Valid { - return "", "", fmt.Errorf("invalid token claims") + return "", "", "", nil, nil, fmt.Errorf("invalid token claims") } sub, _ := claims.GetSubject() if sub == "" { - return "", "", fmt.Errorf("missing subject claim") + return "", "", "", nil, nil, fmt.Errorf("missing subject claim") } methodClaim, _ := claims["method"].(string) + emailClaim, _ := claims["email"].(string) + groups = extractGroupsClaim(claims["groups"]) + groupNames = extractGroupsClaim(claims["group_names"]) - return sub, methodClaim, nil + return sub, emailClaim, methodClaim, groups, groupNames, nil +} + +// extractGroupsClaim decodes the "groups" claim into a string slice. The JWT +// library decodes JSON arrays as []interface{}, so we coerce element-wise +// and skip non-string entries silently. +func extractGroupsClaim(claim interface{}) []string { + raw, ok := claim.([]interface{}) + if !ok { + return nil + } + if len(raw) == 0 { + return nil + } + groups := make([]string, 0, len(raw)) + for _, v := range raw { + if s, ok := v.(string); ok && s != "" { + groups = append(groups, s) + } + } + if len(groups) == 0 { + return nil + } + return groups } diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index ec8980ad9..5970886da 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -63,6 +63,7 @@ var ( preSharedKey string supportsCustomPorts bool requireSubdomain bool + private bool geoDataDir string crowdsecAPIURL string crowdsecAPIKey string @@ -105,6 +106,8 @@ func init() { rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") rootCmd.Flags().BoolVar(&supportsCustomPorts, "supports-custom-ports", envBoolOrDefault("NB_PROXY_SUPPORTS_CUSTOM_PORTS", true), "Whether the proxy can bind arbitrary ports for UDP/TCP passthrough") rootCmd.Flags().BoolVar(&requireSubdomain, "require-subdomain", envBoolOrDefault("NB_PROXY_REQUIRE_SUBDOMAIN", false), "Require a subdomain label in front of the cluster domain") + rootCmd.Flags().BoolVar(&private, "private", envBoolOrDefault("NB_PROXY_PRIVATE", false), "Enable private services accessible with NetBird-Only authentication mode.") + _ = rootCmd.Flags().MarkHidden("private") rootCmd.Flags().DurationVar(&maxDialTimeout, "max-dial-timeout", envDurationOrDefault("NB_PROXY_MAX_DIAL_TIMEOUT", 0), "Cap per-service backend dial timeout (0 = no cap)") rootCmd.Flags().DurationVar(&maxSessionIdleTimeout, "max-session-idle-timeout", envDurationOrDefault("NB_PROXY_MAX_SESSION_IDLE_TIMEOUT", 0), "Cap per-service session idle timeout (0 = no cap)") rootCmd.Flags().StringVar(&geoDataDir, "geo-data-dir", envStringOrDefault("NB_PROXY_GEO_DATA_DIR", "/var/lib/netbird/geolocation"), "Directory for the GeoLite2 MMDB file (auto-downloaded if missing)") @@ -161,7 +164,8 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid --trusted-proxies: %w", err) } - srv := proxy.Server{ + srv := proxy.New(proxy.Config{ + ListenAddr: addr, Logger: logger, Version: Version, ManagementAddress: mgmtAddr, @@ -178,7 +182,7 @@ func runServer(cmd *cobra.Command, args []string) error { ACMEChallengeType: acmeChallengeType, DebugEndpointEnabled: debugEndpoint, DebugEndpointAddress: debugEndpointAddr, - HealthAddress: healthAddr, + HealthAddr: healthAddr, ForwardedProto: forwardedProto, TrustedProxies: parsedTrustedProxies, CertLockMethod: nbacme.CertLockMethod(certLockMethod), @@ -188,12 +192,13 @@ func runServer(cmd *cobra.Command, args []string) error { PreSharedKey: preSharedKey, SupportsCustomPorts: supportsCustomPorts, RequireSubdomain: requireSubdomain, + Private: private, MaxDialTimeout: maxDialTimeout, MaxSessionIdleTimeout: maxSessionIdleTimeout, GeoDataDir: geoDataDir, CrowdSecAPIURL: crowdsecAPIURL, CrowdSecAPIKey: crowdsecAPIKey, - } + }) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) defer stop() diff --git a/proxy/inbound.go b/proxy/inbound.go new file mode 100644 index 000000000..8165b331f --- /dev/null +++ b/proxy/inbound.go @@ -0,0 +1,547 @@ +package proxy + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + stdlog "log" + "net" + "net/http" + "net/netip" + "strconv" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/proxy/internal/auth" + "github.com/netbirdio/netbird/proxy/internal/debug" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +// httpInboundReadHeaderTimeout matches the host-listener read header timeout +// so per-account http.Servers don't leak idle connections. +const httpInboundReadHeaderTimeout = 30 * time.Second + +// httpInboundIdleTimeout caps idle keep-alive on per-account inbound HTTP +// servers; matches the host listener. +const httpInboundIdleTimeout = 90 * time.Second + +// inboundShutdownTimeout caps how long a per-account http.Server gets to +// drain in-flight requests during teardown. +const inboundShutdownTimeout = 5 * time.Second + +// privateInboundPortHTTPS is the WG-side TLS port. Each account's +// embedded netstack binds independently, so a fixed port is fine. +const privateInboundPortHTTPS = 443 + +// privateInboundPortHTTP is the WG-side plain-HTTP port. +const privateInboundPortHTTP = 80 + +// inboundManager wires per-account inbound listeners into the proxy +// pipeline when --private-inbound is enabled. When disabled the manager +// is nil and every method on *Server that touches it short-circuits. +type inboundManager struct { + logger *log.Logger + handler http.Handler + tlsConfig *tls.Config + // muxLock guards entries and pendingRoutes. + muxLock sync.Mutex + entries map[types.AccountID]*inboundEntry + pendingRoutes map[types.AccountID][]pendingInboundRoute +} + +// inboundEntry owns the listeners, router and HTTP servers for a single +// account's embedded netstack. +type inboundEntry struct { + router *nbtcp.Router + tlsListener net.Listener + plainListener net.Listener + httpsServer *http.Server + httpServer *http.Server + cancel context.CancelFunc + wg sync.WaitGroup +} + +// pendingInboundRoute holds a route that arrived before the account's +// listener finished starting. +type pendingInboundRoute struct { + host nbtcp.SNIHost + route nbtcp.Route +} + +// newInboundManager constructs a manager bound to the proxy's HTTP +// handler chain and TLS config. +func newInboundManager(logger *log.Logger, handler http.Handler, tlsConfig *tls.Config) *inboundManager { + return &inboundManager{ + logger: logger, + handler: handler, + tlsConfig: tlsConfig, + entries: make(map[types.AccountID]*inboundEntry), + pendingRoutes: make(map[types.AccountID][]pendingInboundRoute), + } +} + +// onClientReady is registered with NetBird.SetClientLifecycle so the +// listener pair comes up exactly when the embedded client reports ready. +// The returned value is opaque to the roundtrip package; it is handed +// back verbatim to onClientStop on teardown. +func (m *inboundManager) onClientReady(ctx context.Context, accountID types.AccountID, client *embed.Client) any { + if m == nil { + return nil + } + entry, err := m.bringUp(ctx, accountID, client) + if err != nil { + m.logger.WithField("account_id", accountID).WithError(err).Warn("failed to start per-account inbound listener; continuing without inbound") + return nil + } + + m.flushPending(accountID, entry) + + m.logger.WithFields(log.Fields{ + "account_id": accountID, + "https": entry.tlsListener.Addr().String(), + "http": entry.plainListener.Addr().String(), + }).Info("per-account inbound listeners up") + return entry +} + +// onClientStop tears down a per-account listener bundle. State is the +// opaque value previously returned by onClientReady. +func (m *inboundManager) onClientStop(accountID types.AccountID, state any) { + if m == nil { + return + } + entry, ok := state.(*inboundEntry) + if !ok || entry == nil { + return + } + m.tearDown(accountID, entry) +} + +// bringUp opens both listeners on the account's netstack, builds the +// router, and starts the parallel HTTP servers. +func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID, client *embed.Client) (*inboundEntry, error) { + tlsListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTPS)) + if err != nil { + return nil, fmt.Errorf("listen tls on netstack: %w", err) + } + plainListener, err := client.ListenTCP(fmt.Sprintf(":%d", privateInboundPortHTTP)) + if err != nil { + _ = tlsListener.Close() + return nil, fmt.Errorf("listen plain on netstack: %w", err) + } + + router := nbtcp.NewRouter(m.logger, accountDialResolver(accountID, client), tlsListener.Addr(), nbtcp.WithPlainHTTP(plainListener.Addr())) + + scopedHandler := withTunnelLookup(m.handler, accountTunnelLookup(client)) + + // markOverlayOrigin stamps every connection accepted by an inbound + // listener with a context value middlewares can read to skip + // geo/CrowdSec checks (the source address is always inside the + // NetBird CGNAT range and won't match either dataset). + markOverlayOrigin := func(ctx context.Context, _ net.Conn) context.Context { + return types.WithOverlayOrigin(ctx) + } + + httpsServer := &http.Server{ + Handler: scopedHandler, + TLSConfig: m.tlsConfig, + ReadHeaderTimeout: httpInboundReadHeaderTimeout, + IdleTimeout: httpInboundIdleTimeout, + ErrorLog: newInboundErrorLog(m.logger, "https", accountID), + ConnContext: markOverlayOrigin, + } + httpServer := &http.Server{ + Handler: scopedHandler, + ReadHeaderTimeout: httpInboundReadHeaderTimeout, + IdleTimeout: httpInboundIdleTimeout, + ErrorLog: newInboundErrorLog(m.logger, "http", accountID), + ConnContext: markOverlayOrigin, + } + + runCtx, cancel := context.WithCancel(ctx) + entry := &inboundEntry{ + router: router, + tlsListener: tlsListener, + plainListener: plainListener, + httpsServer: httpsServer, + httpServer: httpServer, + cancel: cancel, + } + + entry.wg.Add(1) + go func() { + defer entry.wg.Done() + if err := router.Serve(runCtx, tlsListener); err != nil { + m.logger.WithField("account_id", accountID).Debugf("per-account router stopped: %v", err) + } + }() + + entry.wg.Add(1) + go func() { + defer entry.wg.Done() + if err := httpsServer.ServeTLS(router.HTTPListener(), "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + m.logger.WithField("account_id", accountID).Debugf("per-account https server stopped: %v", err) + } + }() + + entry.wg.Add(1) + go func() { + defer entry.wg.Done() + if err := httpServer.Serve(router.HTTPListenerPlain()); err != nil && !errors.Is(err, http.ErrServerClosed) { + m.logger.WithField("account_id", accountID).Debugf("per-account http server stopped: %v", err) + } + }() + + entry.wg.Add(1) + go func() { + defer entry.wg.Done() + feedRouterFromListener(runCtx, plainListener, router, m.logger, accountID) + }() + + m.muxLock.Lock() + m.entries[accountID] = entry + m.muxLock.Unlock() + + return entry, nil +} + +// tearDown shuts every goroutine down and closes the netstack listeners. +func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry) { + m.muxLock.Lock() + if m.entries[accountID] == entry { + delete(m.entries, accountID) + delete(m.pendingRoutes, accountID) + } + m.muxLock.Unlock() + + entry.cancel() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), inboundShutdownTimeout) + defer cancel() + + if err := entry.httpsServer.Shutdown(shutdownCtx); err != nil { + m.logger.Debugf("per-account https shutdown: %v", err) + } + if err := entry.httpServer.Shutdown(shutdownCtx); err != nil { + m.logger.Debugf("per-account http shutdown: %v", err) + } + if err := entry.tlsListener.Close(); err != nil { + m.logger.Debugf("close per-account tls listener: %v", err) + } + if err := entry.plainListener.Close(); err != nil { + m.logger.Debugf("close per-account plain listener: %v", err) + } + entry.wg.Wait() +} + +// AddRoute records an SNI/host route on the account's per-account router. +// Routes registered before the listener is up are queued and replayed +// once startup completes. +func (m *inboundManager) AddRoute(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) { + if m == nil { + return + } + m.muxLock.Lock() + entry, ok := m.entries[accountID] + if !ok { + m.queuePendingLocked(accountID, host, route) + m.muxLock.Unlock() + return + } + router := entry.router + m.muxLock.Unlock() + + router.AddRoute(host, route) +} + +// RemoveRoute drops a previously registered route. Safe to call when the +// listener is not yet up; queued copies are pruned in that case. +func (m *inboundManager) RemoveRoute(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) { + if m == nil { + return + } + m.muxLock.Lock() + m.dropPendingLocked(accountID, host, svcID) + entry, ok := m.entries[accountID] + if !ok { + m.muxLock.Unlock() + return + } + router := entry.router + m.muxLock.Unlock() + + router.RemoveRoute(host, svcID) +} + +// queuePendingLocked stores or upserts a pending route. Caller holds muxLock. +func (m *inboundManager) queuePendingLocked(accountID types.AccountID, host nbtcp.SNIHost, route nbtcp.Route) { + queued := m.pendingRoutes[accountID] + for i, pr := range queued { + if pr.host == host && pr.route.ServiceID == route.ServiceID { + queued[i] = pendingInboundRoute{host: host, route: route} + m.pendingRoutes[accountID] = queued + return + } + } + m.pendingRoutes[accountID] = append(queued, pendingInboundRoute{host: host, route: route}) +} + +// dropPendingLocked removes any queued route matching host/svcID. +// Caller holds muxLock. +func (m *inboundManager) dropPendingLocked(accountID types.AccountID, host nbtcp.SNIHost, svcID types.ServiceID) { + queued, ok := m.pendingRoutes[accountID] + if !ok { + return + } + filtered := queued[:0] + for _, pr := range queued { + if pr.host == host && pr.route.ServiceID == svcID { + continue + } + filtered = append(filtered, pr) + } + if len(filtered) == 0 { + delete(m.pendingRoutes, accountID) + return + } + m.pendingRoutes[accountID] = filtered +} + +// flushPending applies all queued routes to a freshly-up router. +func (m *inboundManager) flushPending(accountID types.AccountID, entry *inboundEntry) { + m.muxLock.Lock() + queued := m.pendingRoutes[accountID] + delete(m.pendingRoutes, accountID) + m.muxLock.Unlock() + + for _, pr := range queued { + entry.router.AddRoute(pr.host, pr.route) + } +} + +// HasInbound reports whether the manager has a live listener for the account. +// Used by tests. +func (m *inboundManager) HasInbound(accountID types.AccountID) bool { + if m == nil { + return false + } + m.muxLock.Lock() + defer m.muxLock.Unlock() + _, ok := m.entries[accountID] + return ok +} + +// PendingRouteCount reports the number of queued routes for the account. +// Used by tests. +func (m *inboundManager) PendingRouteCount(accountID types.AccountID) int { + if m == nil { + return 0 + } + m.muxLock.Lock() + defer m.muxLock.Unlock() + return len(m.pendingRoutes[accountID]) +} + +// InboundListenerInfo describes the bound addresses of a single +// per-account inbound listener. Both addresses live on the embedded +// netstack of the account's WireGuard client and share the same tunnel IP. +type InboundListenerInfo struct { + TunnelIP string + HTTPSPort uint16 + HTTPPort uint16 +} + +// ListenerInfo returns the inbound listener addresses for the given +// account, or ok=false when the account has no live listener. Used by +// the status-update RPC and the debug HTTP handler to surface inbound +// reachability to operators. +func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListenerInfo, bool) { + if m == nil { + return InboundListenerInfo{}, false + } + m.muxLock.Lock() + defer m.muxLock.Unlock() + entry, ok := m.entries[accountID] + if !ok || entry == nil { + return InboundListenerInfo{}, false + } + return listenerInfoFromEntry(entry), true +} + +// Snapshot returns the inbound listener state for every account that has +// a live listener at call time. Empty when --private-inbound is off or +// no accounts have come up yet. +func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo { + if m == nil { + return nil + } + m.muxLock.Lock() + defer m.muxLock.Unlock() + if len(m.entries) == 0 { + return nil + } + out := make(map[types.AccountID]InboundListenerInfo, len(m.entries)) + for id, entry := range m.entries { + if entry == nil { + continue + } + out[id] = listenerInfoFromEntry(entry) + } + return out +} + +// listenerInfoFromEntry extracts the tunnel IP and ports from a live +// per-account entry. Both listeners are bound on the same netstack so +// their host components match; we still pull the TLS host as the +// authoritative source. +func listenerInfoFromEntry(entry *inboundEntry) InboundListenerInfo { + info := InboundListenerInfo{HTTPSPort: privateInboundPortHTTPS, HTTPPort: privateInboundPortHTTP} + if entry.tlsListener != nil { + host, port := splitHostPort(entry.tlsListener.Addr()) + info.TunnelIP = host + if port != 0 { + info.HTTPSPort = port + } + } + if entry.plainListener != nil { + host, port := splitHostPort(entry.plainListener.Addr()) + if info.TunnelIP == "" { + info.TunnelIP = host + } + if port != 0 { + info.HTTPPort = port + } + } + return info +} + +// splitHostPort extracts host and port from a net.Addr, returning the +// zero values when the address is missing or malformed. +func splitHostPort(addr net.Addr) (string, uint16) { + if addr == nil { + return "", 0 + } + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + return "", 0 + } + if portStr == "" { + return host, 0 + } + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return host, 0 + } + return host, uint16(port) +} + +// feedRouterFromListener accepts on the plain-HTTP netstack listener and +// hands every connection to the account's router. The router peeks the +// first byte and dispatches to the plain-HTTP channel for non-TLS +// streams or the TLS channel for ClientHellos that arrive on :80. +func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.Router, logger *log.Logger, accountID types.AccountID) { + go func() { + <-ctx.Done() + _ = ln.Close() + }() + + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { + return + } + logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err) + continue + } + router.HandleConn(ctx, conn) + } +} + +// accountDialResolver returns a DialResolver bound to a single account's +// embedded client. The router only ever serves traffic for that account +// so the supplied accountID is ignored at dial time. +func accountDialResolver(_ types.AccountID, client *embed.Client) nbtcp.DialResolver { + return func(_ types.AccountID) (types.DialContextFunc, error) { + return client.DialContext, nil + } +} + +// accountTunnelLookup returns a TunnelLookupFunc backed by the embedded +// client's peerstore for a single account. Phase 3 uses the result to +// short-circuit ValidateTunnelPeer when the source IP is not in the +// account's roster and to seed the cached identity for known peers. +func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc { + if client == nil { + return nil + } + return func(ip netip.Addr) (auth.PeerIdentity, bool) { + pubKey, fqdn, ok := client.IdentityForIP(ip) + if !ok { + return auth.PeerIdentity{}, false + } + return auth.PeerIdentity{ + PubKey: pubKey, + TunnelIP: ip, + FQDN: fqdn, + }, true + } +} + +// withTunnelLookup returns an http.Handler that attaches the per-account +// peerstore lookup to every request's context before delegating to next. +// Calling on the host-level listener is a no-op because that path never +// installs this wrapper, so the existing behaviour stays byte-for-byte +// identical when --private-inbound is off or the request didn't arrive +// on a per-account listener. +func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler { + if lookup == nil { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := auth.WithTunnelLookup(r.Context(), lookup) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// inboundDebugAdapter adapts *inboundManager to the debug.InboundProvider +// interface so the debug HTTP handler can render per-account inbound +// listener state without importing the proxy package. +type inboundDebugAdapter struct { + mgr *inboundManager +} + +// InboundListeners returns a snapshot of the live per-account inbound +// listeners formatted for the debug surface. +func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.InboundListenerInfo { + if a.mgr == nil { + return nil + } + snap := a.mgr.Snapshot() + if len(snap) == 0 { + return nil + } + out := make(map[types.AccountID]debug.InboundListenerInfo, len(snap)) + for id, info := range snap { + out[id] = debug.InboundListenerInfo{ + TunnelIP: info.TunnelIP, + HTTPSPort: info.HTTPSPort, + HTTPPort: info.HTTPPort, + } + } + return out +} + +// newInboundErrorLog routes a per-account http.Server's stdlib error +// stream through logrus at warn level. +func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger { + return stdlog.New(logger.WithFields(log.Fields{ + "inbound-http": scheme, + "account_id": accountID, + }).WriterLevel(log.WarnLevel), "", 0) +} diff --git a/proxy/inbound_test.go b/proxy/inbound_test.go new file mode 100644 index 000000000..a868f1c12 --- /dev/null +++ b/proxy/inbound_test.go @@ -0,0 +1,502 @@ +package proxy + +import ( + "bufio" + "context" + "crypto/tls" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "sync" + "sync/atomic" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/auth" + "github.com/netbirdio/netbird/proxy/internal/roundtrip" + nbtcp "github.com/netbirdio/netbird/proxy/internal/tcp" + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// bufioReader wraps the connection in a buffered reader so http.ReadResponse +// can parse the response line + headers off the wire. +func bufioReader(conn net.Conn) *bufio.Reader { + return bufio.NewReader(conn) +} + +// quietLogger returns a logger that emits nothing โ€” keeps test output tidy. +func quietLogger() *log.Logger { + logger := log.New() + logger.SetLevel(log.PanicLevel) + return logger +} + +func TestInboundManager_RouteScopedToAccount(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + + accountA := types.AccountID("acct-a") + accountB := types.AccountID("acct-b") + + mgr.AddRoute(accountA, "shared.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountA, ServiceID: "svc-a", Domain: "shared.example"}) + mgr.AddRoute(accountB, "other.example", nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: "other.example"}) + + require.Equal(t, 1, mgr.PendingRouteCount(accountA), "account A should have one queued route") + require.Equal(t, 1, mgr.PendingRouteCount(accountB), "account B should have one queued route") + + mgr.RemoveRoute(accountA, "shared.example", "svc-a") + mgr.RemoveRoute(accountB, "other.example", "svc-b") + + assert.Equal(t, 0, mgr.PendingRouteCount(accountA), "queue should drain on remove") + assert.Equal(t, 0, mgr.PendingRouteCount(accountB), "queue should drain on remove") +} + +func TestInboundManager_PendingThenFlush(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + + accountID := types.AccountID("acct-1") + host := nbtcp.SNIHost("example.test") + route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-1", Domain: "example.test"} + + mgr.AddRoute(accountID, host, route) + require.Equal(t, 1, mgr.PendingRouteCount(accountID), "pending count before listener is up") + + // Simulate listener up by registering a fake entry, then flushing. + router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"}) + entry := &inboundEntry{router: router} + mgr.muxLock.Lock() + mgr.entries[accountID] = entry + mgr.muxLock.Unlock() + + mgr.flushPending(accountID, entry) + assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "queue should be empty after flush") +} + +// fakeAddr is a stub net.Addr for tests that don't actually bind sockets. +type fakeAddr struct { + addr string +} + +func (a *fakeAddr) Network() string { return "tcp" } +func (a *fakeAddr) String() string { return a.addr } + +// fakeMgmtClient implements roundtrip.managementClient for tests. +type fakeMgmtClient struct{} + +func (fakeMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) { + return &proto.CreateProxyPeerResponse{Success: true}, nil +} + +// TestServer_PrivateInbound_NotEnabled_NoManager confirms that with +// --private off the inbound manager is nil and the standalone proxy +// keeps its zero-overhead default path. +func TestServer_PrivateInbound_NotEnabled_NoManager(t *testing.T) { + s := &Server{Logger: quietLogger(), Private: false} + s.initPrivateInbound(http.NotFoundHandler(), nil) + assert.Nil(t, s.inbound, "manager should remain nil when --private is off") +} + +// TestServer_PrivateInbound_Enabled_WiresLifecycle confirms that +// --private alone wires the manager into the NetBird transport, so +// AddPeer / RemovePeer drive the lifecycle. +func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) { + s := &Server{Logger: quietLogger(), Private: true} + // Construct a NetBird transport. We can't actually start the embedded + // client here (that needs a real management server), but we can + // confirm that the lifecycle callbacks are registered. + s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{ + MgmtAddr: "http://invalid.test", + }, quietLogger(), nil, fakeMgmtClient{}) + + s.initPrivateInbound(http.NotFoundHandler(), &tls.Config{}) //nolint:gosec + require.NotNil(t, s.inbound, "manager should be set when --private is on") + assert.NotNil(t, s.inbound.handler, "handler should be set on manager") + assert.NotNil(t, s.inbound.tlsConfig, "tls config should be set on manager") +} + +// TestInboundManager_AddRouteAfterReady_RegistersDirectly verifies that +// when the listener is already up, AddRoute writes straight to the +// router without queueing. +func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + accountID := types.AccountID("acct-1") + router := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"}) + + mgr.muxLock.Lock() + mgr.entries[accountID] = &inboundEntry{router: router} + mgr.muxLock.Unlock() + + host := nbtcp.SNIHost("ready.example") + mgr.AddRoute(accountID, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: "svc-ready", Domain: string(host)}) + assert.Equal(t, 0, mgr.PendingRouteCount(accountID), "no pending entries when listener is up") +} + +// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability +// bit reported upstream tracks --private exclusively. The previous +// --private-inbound flag has been folded into --private. +func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) { + tests := []struct { + name string + private bool + expected bool + }{ + {"off", false, false}, + {"on", true, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Server{Private: tt.private} + assert.Equal(t, tt.expected, s.Private, "private capability bit should match --private") + }) + } +} + +// TestInboundManager_RouteScopedToAccountB_DoesNotMatchA verifies that a +// service registered for account B is invisible to a router serving +// account A. We exercise the path through real per-account routers. +func TestInboundManager_RouteScopedToAccountB_DoesNotMatchA(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + + accountA := types.AccountID("acct-a") + accountB := types.AccountID("acct-b") + routerA := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"}) + routerB := nbtcp.NewRouter(quietLogger(), nil, &fakeAddr{addr: "127.0.0.1:0"}) + + mgr.muxLock.Lock() + mgr.entries[accountA] = &inboundEntry{router: routerA} + mgr.entries[accountB] = &inboundEntry{router: routerB} + mgr.muxLock.Unlock() + + host := nbtcp.SNIHost("shared.example") + mgr.AddRoute(accountB, host, nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountB, ServiceID: "svc-b", Domain: string(host)}) + + // Account A's router should have no routes; account B's should have one. + // We check via IsEmpty โ€” true means no routes and no fallback. + assert.True(t, routerA.IsEmpty(), "account A router must not see account B's mappings") + assert.False(t, routerB.IsEmpty(), "account B router should hold its own mapping") +} + +// TestInboundEntry_ShutdownIdempotent ensures that tearDown can run twice +// without panicking โ€” callers may invoke it from RemovePeer + StopAll. +func TestInboundEntry_ShutdownIdempotent(t *testing.T) { + t.Skip("teardown requires real netstack listeners; covered by integration tests") +} + +// TestRouter_PlainHTTP_ForwardedProtoIsHTTP exercises the full per-account +// router pipeline against a loopback listener (proxy of a netstack +// listener for test purposes): a plain HTTP request lands on the plain +// http.Server and the inner handler observes a nil r.TLS, which is what +// auth.ResolveProto translates to "http" in the real pipeline. +func TestRouter_PlainHTTP_ForwardedProtoIsHTTP(t *testing.T) { + logger := quietLogger() + + var captured atomic.Value + captured.Store("") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil { + captured.Store("http") + } else { + captured.Store("https") + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + hostListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "loopback listener bind must succeed") + defer hostListener.Close() + + router := nbtcp.NewRouter(logger, nil, hostListener.Addr(), nbtcp.WithPlainHTTP(hostListener.Addr())) + httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second} + defer func() { _ = httpServer.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }() + go func() { _ = router.Serve(ctx, hostListener) }() + + conn, err := net.DialTimeout("tcp", hostListener.Addr().String(), 2*time.Second) + require.NoError(t, err, "plain HTTP dial must succeed") + defer conn.Close() + + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n")) + require.NoError(t, err, "write must succeed") + + resp, err := http.ReadResponse(bufioReader(conn), nil) + require.NoError(t, err, "must read response") + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "http", captured.Load(), "ForwardedProto must be http on plain path") +} + +// TestWithTunnelLookup_AttachesLookupToContext verifies that requests +// flowing through the per-account handler wrapper carry the peerstore +// lookup function. Phase 3's local-first deny path depends on this. +func TestWithTunnelLookup_AttachesLookupToContext(t *testing.T) { + expected := auth.PeerIdentity{TunnelIP: netip.MustParseAddr("100.64.0.10"), FQDN: "peer.netbird"} + lookup := auth.TunnelLookupFunc(func(_ netip.Addr) (auth.PeerIdentity, bool) { + return expected, true + }) + + var observed auth.TunnelLookupFunc + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + observed = auth.TunnelLookupFromContext(r.Context()) + }) + + handler := withTunnelLookup(inner, lookup) + r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil) + handler.ServeHTTP(httptest.NewRecorder(), r) + + require.NotNil(t, observed, "wrapper must inject the lookup into the request context") + got, ok := observed(netip.MustParseAddr("100.64.0.10")) + assert.True(t, ok, "lookup must round-trip through context") + assert.Equal(t, expected.FQDN, got.FQDN, "lookup must return the same identity it was constructed with") +} + +// TestWithTunnelLookup_NilLookupIsNoop confirms the wrapper is a pure +// pass-through when no lookup is provided. Required for the host-level +// listener path to keep its byte-for-byte previous behaviour. +func TestWithTunnelLookup_NilLookupIsNoop(t *testing.T) { + var called bool + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + assert.Nil(t, auth.TunnelLookupFromContext(r.Context()), "host-level path must not see a lookup function") + }) + + handler := withTunnelLookup(inner, nil) + r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil) + handler.ServeHTTP(httptest.NewRecorder(), r) + assert.True(t, called, "wrapper without lookup must still invoke next") +} + +// fakeListener satisfies net.Listener for snapshot tests without binding +// a real socket on the netstack. +type fakeListener struct { + addr net.Addr +} + +func (f *fakeListener) Accept() (net.Conn, error) { return nil, net.ErrClosed } +func (f *fakeListener) Close() error { return nil } +func (f *fakeListener) Addr() net.Addr { return f.addr } + +// TestInboundManager_ListenerInfo confirms ListenerInfo and Snapshot +// surface the bound tunnel-IP and ports for live entries. +func TestInboundManager_ListenerInfo(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + accountID := types.AccountID("acct-info") + + tlsAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTPS} + plainAddr := &net.TCPAddr{IP: net.ParseIP("100.64.0.5"), Port: privateInboundPortHTTP} + mgr.muxLock.Lock() + mgr.entries[accountID] = &inboundEntry{ + tlsListener: &fakeListener{addr: tlsAddr}, + plainListener: &fakeListener{addr: plainAddr}, + } + mgr.muxLock.Unlock() + + info, ok := mgr.ListenerInfo(accountID) + require.True(t, ok, "ListenerInfo must report ok for live entry") + assert.Equal(t, "100.64.0.5", info.TunnelIP, "tunnel IP must come from listener address") + assert.Equal(t, uint16(privateInboundPortHTTPS), info.HTTPSPort, "TLS port must match bound port") + assert.Equal(t, uint16(privateInboundPortHTTP), info.HTTPPort, "HTTP port must match bound port") + + snap := mgr.Snapshot() + require.Len(t, snap, 1, "snapshot must contain exactly one entry") + assert.Equal(t, info, snap[accountID], "snapshot entry must equal direct lookup") + + _, ok = mgr.ListenerInfo(types.AccountID("missing")) + assert.False(t, ok, "ListenerInfo must report ok=false for unknown accounts") +} + +// TestInboundManager_NilManagerSafe ensures the observability accessors +// are safe to call when --private-inbound is off (nil manager). +func TestInboundManager_NilManagerSafe(t *testing.T) { + var mgr *inboundManager + _, ok := mgr.ListenerInfo("anything") + assert.False(t, ok, "nil manager must return ok=false") + assert.Nil(t, mgr.Snapshot(), "nil manager must return nil snapshot") +} + +// TestInboundManager_ConcurrentAddRemove pounds AddRoute / RemoveRoute +// from multiple goroutines to expose any locking gaps. +func TestInboundManager_ConcurrentAddRemove(t *testing.T) { + mgr := newInboundManager(quietLogger(), http.NotFoundHandler(), nil) + accountID := types.AccountID("acct-1") + const workers = 32 + const iterations = 50 + + var wg sync.WaitGroup + wg.Add(workers) + for i := 0; i < workers; i++ { + go func(idx int) { + defer wg.Done() + host := nbtcp.SNIHost("example.test") + svc := types.ServiceID("svc") + route := nbtcp.Route{Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svc, Domain: "example.test"} + for j := 0; j < iterations; j++ { + mgr.AddRoute(accountID, host, route) + mgr.RemoveRoute(accountID, host, svc) + } + }(i) + } + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("concurrent add/remove timed out") + } +} + +// TestFeedRouterFromListener_DeliversConnectionToHandler validates the +// per-account inbound chain end-to-end with a loopback listener +// substituted for the embedded netstack: a TCP connection arriving at +// the plain listener flows through feedRouterFromListener, the router's +// peek-and-dispatch, the wrapped HTTP server, and reaches the user +// handler. If the embedded netstack is delivering connections at all, +// this is the path they take. Failures localise to wiring bugs in the +// proxy, not the netstack. +func TestFeedRouterFromListener_DeliversConnectionToHandler(t *testing.T) { + logger := quietLogger() + + hits := make(chan string, 1) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits <- r.Host + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("served")) + }) + + plainLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "plain loopback bind must succeed") + t.Cleanup(func() { _ = plainLn.Close() }) + + router := nbtcp.NewRouter(logger, nil, &fakeAddr{addr: "127.0.0.1:0"}, nbtcp.WithPlainHTTP(plainLn.Addr())) + + httpServer := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second} + t.Cleanup(func() { _ = httpServer.Close() }) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { _ = httpServer.Serve(router.HTTPListenerPlain()) }() + go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-1")) + + conn, err := net.DialTimeout("tcp", plainLn.Addr().String(), 2*time.Second) + require.NoError(t, err, "must connect to the plain listener") + t.Cleanup(func() { _ = conn.Close() }) + + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: app.example\r\nConnection: close\r\n\r\n")) + require.NoError(t, err, "request write must succeed") + + resp, err := http.ReadResponse(bufioReader(conn), nil) + require.NoError(t, err, "must read response from server") + t.Cleanup(func() { _ = resp.Body.Close() }) + + assert.Equal(t, http.StatusOK, resp.StatusCode, "handler must be reached") + + select { + case host := <-hits: + assert.Equal(t, "app.example", host, "handler must observe the request Host") + case <-time.After(2 * time.Second): + t.Fatal("handler was not invoked โ€” connection did not flow through router โ†’ http server") + } +} + +// TestFeedRouterFromListener_DispatchesTLSToTLSChannel verifies that a +// TLS ClientHello arriving on the plain listener is detected by the +// router peek and re-dispatched to the TLS channel โ€” the cross-channel +// fallback the inbound stack relies on for HTTPS-on-:80 testing. +func TestFeedRouterFromListener_DispatchesTLSToTLSChannel(t *testing.T) { + logger := quietLogger() + + hits := make(chan string, 1) + tlsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits <- r.Host + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("served-tls")) + }) + + plainLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "plain loopback bind must succeed") + t.Cleanup(func() { _ = plainLn.Close() }) + + tlsLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "tls loopback bind must succeed") + t.Cleanup(func() { _ = tlsLn.Close() }) + + router := nbtcp.NewRouter(logger, nil, tlsLn.Addr(), nbtcp.WithPlainHTTP(plainLn.Addr())) + + tlsConfig := selfSignedTLSConfig(t) + httpsServer := &http.Server{ + Handler: tlsHandler, + TLSConfig: tlsConfig, + ReadHeaderTimeout: time.Second, + } + t.Cleanup(func() { _ = httpsServer.Close() }) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { _ = httpsServer.ServeTLS(router.HTTPListener(), "", "") }() + go feedRouterFromListener(ctx, plainLn, router, logger, types.AccountID("acct-tls")) + + tlsConn, err := tls.Dial("tcp", plainLn.Addr().String(), &tls.Config{InsecureSkipVerify: true}) //nolint:gosec + require.NoError(t, err, "TLS dial against the plain listener must succeed (cross-channel)") + t.Cleanup(func() { _ = tlsConn.Close() }) + + req, err := http.NewRequest(http.MethodGet, "https://app.example/", nil) + require.NoError(t, err) + require.NoError(t, req.Write(tlsConn), "TLS request write must succeed") + + resp, err := http.ReadResponse(bufioReader(tlsConn), req) + require.NoError(t, err, "must read TLS response") + t.Cleanup(func() { _ = resp.Body.Close() }) + + assert.Equal(t, http.StatusOK, resp.StatusCode, "TLS handler must be reached") + + select { + case host := <-hits: + assert.Equal(t, "app.example", host, "TLS handler must observe the request Host") + case <-time.After(2 * time.Second): + t.Fatal("TLS handler was not invoked โ€” peek/dispatch path is broken") + } +} + +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + cert, err := tls.X509KeyPair(testCertPEM, testKeyPEM) + require.NoError(t, err, "load static self-signed cert") + return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec +} + +// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for +// 127.0.0.1 โ€” only used by tests that need a working TLS handshake. +var testCertPEM = []byte(`-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----`) +var testKeyPEM = []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----`) diff --git a/proxy/internal/auth/identity.go b/proxy/internal/auth/identity.go new file mode 100644 index 000000000..c49e0caa9 --- /dev/null +++ b/proxy/internal/auth/identity.go @@ -0,0 +1,47 @@ +package auth + +import ( + "context" + "net/netip" +) + +// PeerIdentity describes the locally-known facts about a peer reachable on +// the proxy's per-account WireGuard listener. Phase 3 fills PubKey, TunnelIP +// and FQDN from the embedded client's peerstore. UserID, Email and Groups +// stay zero in V1 โ€” full identity still travels through ValidateTunnelPeer. +// Phase V2 will populate them once RemotePeerConfig carries user identity. +type PeerIdentity struct { + PubKey string + TunnelIP netip.Addr + FQDN string + + // V2 fields (zero in V1). + UserID string + Email string + Groups []string +} + +// TunnelLookupFunc resolves a tunnel IP to a peer identity using locally +// available peerstore data. ok=false means the IP is not in the calling +// account's roster. +type TunnelLookupFunc func(ip netip.Addr) (PeerIdentity, bool) + +type tunnelLookupContextKey struct{} + +// WithTunnelLookup attaches a per-account peerstore lookup function to +// the request context. The auth middleware calls this lookup before +// hitting management's ValidateTunnelPeer to short-circuit unknown IPs +// and to skip the RPC for already-cached identities. +func WithTunnelLookup(ctx context.Context, lookup TunnelLookupFunc) context.Context { + if lookup == nil { + return ctx + } + return context.WithValue(ctx, tunnelLookupContextKey{}, lookup) +} + +// TunnelLookupFromContext returns the peerstore lookup attached to ctx, +// or nil when the request did not arrive on a per-account listener. +func TunnelLookupFromContext(ctx context.Context) TunnelLookupFunc { + v, _ := ctx.Value(tunnelLookupContextKey{}).(TunnelLookupFunc) + return v +} diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 3b383f8b4..a76427ca0 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -36,6 +36,7 @@ type authenticator interface { // SessionValidator validates session tokens and checks user access permissions. type SessionValidator interface { ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error) + ValidateTunnelPeer(ctx context.Context, in *proto.ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) } // Scheme defines an authentication mechanism for a domain. @@ -56,12 +57,21 @@ type DomainConfig struct { AccountID types.AccountID ServiceID types.ServiceID IPRestrictions *restrict.Filter + // Private routes the domain through ValidateTunnelPeer; failure โ†’ 403. + Private bool } type validationResult struct { UserID string + UserEmail string Valid bool DeniedReason string + Groups []string + // GroupNames carries the human-readable display names for Groups, + // ordered identically (positional pairing). May be shorter than + // Groups for tokens minted before names were embedded; the consumer + // falls back to ids for missing positions. + GroupNames []string } // Middleware applies per-domain authentication and IP restriction checks. @@ -71,6 +81,7 @@ type Middleware struct { logger *log.Logger sessionValidator SessionValidator geo restrict.GeoResolver + tunnelCache *tunnelValidationCache } // NewMiddleware creates a new authentication middleware. The sessionValidator is @@ -84,6 +95,7 @@ func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator, geo re logger: logger, sessionValidator: sessionValidator, geo: geo, + tunnelCache: newTunnelValidationCache(), } } @@ -111,6 +123,15 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + // Private services bypass operator schemes and gate on tunnel peer. + if config.Private { + if mw.forwardWithTunnelPeer(w, r, host, config, next) { + return + } + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + // Domains with no authentication schemes pass through after IP checks. if len(config.Schemes) == 0 { next.ServeHTTP(w, r) @@ -129,10 +150,54 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { return } + if mw.forwardWithTunnelPeer(w, r, host, config, next) { + return + } + + if mw.blockOIDCOnPlainHTTP(w, r, config) { + return + } + mw.authenticateWithSchemes(w, r, host, config) }) } +// requestIsPlainHTTP reports whether the request arrived without TLS. +// Used to gate cookie-on-plain warnings and the OIDC plain-HTTP block. +func requestIsPlainHTTP(r *http.Request) bool { + return r.TLS == nil +} + +// hasOIDCScheme reports whether any of the configured schemes requires +// TLS to round-trip safely with an external IdP. +func hasOIDCScheme(schemes []Scheme) bool { + for _, s := range schemes { + if s.Type() == auth.MethodOIDC { + return true + } + } + return false +} + +// blockOIDCOnPlainHTTP fails fast when an OIDC-configured domain is hit +// over plain HTTP. Most IdPs reject http:// redirect URIs, so surfacing +// the misconfiguration here yields a clearer error than the IdP's +// "invalid redirect_uri" round-trip. +func (mw *Middleware) blockOIDCOnPlainHTTP(w http.ResponseWriter, r *http.Request, config DomainConfig) bool { + if !requestIsPlainHTTP(r) { + return false + } + if !hasOIDCScheme(config.Schemes) { + return false + } + mw.logger.WithFields(log.Fields{ + "host": r.Host, + "remote": r.RemoteAddr, + }).Warn("OIDC scheme reached on plain HTTP path; rejecting with 400 โ€” use port 443") + http.Error(w, "OIDC requires TLS โ€” use port 443", http.StatusBadRequest) + return true +} + func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { mw.domainsMux.RLock() defer mw.domainsMux.RUnlock() @@ -162,7 +227,17 @@ func (mw *Middleware) checkIPRestrictions(w http.ResponseWriter, r *http.Request return false } - verdict := config.IPRestrictions.Check(clientIP, mw.geo) + var verdict restrict.Verdict + if types.IsOverlayOrigin(r.Context()) { + // Geo/CrowdSec checks don't apply over the WireGuard overlay: + // the source address is always inside the NetBird CGNAT range, + // which is never in a GeoIP database or a CrowdSec decision + // list. Enforcing them here would either no-op (best case) or + // fail-closed when the geo database is missing. + verdict = config.IPRestrictions.CheckCIDR(clientIP) + } else { + verdict = config.IPRestrictions.Check(clientIP, mw.geo) + } if verdict == restrict.Allow { return true } @@ -246,18 +321,111 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re if err != nil { return false } - userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey) + userID, email, method, groups, groupNames, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey) if err != nil { return false } if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetUserID(userID) + cd.SetUserEmail(email) + cd.SetUserGroups(groups) + cd.SetUserGroupNames(groupNames) cd.SetAuthMethod(method) } next.ServeHTTP(w, r) return true } +// forwardWithTunnelPeer is the OIDC fast-path for requests originating on the +// netbird mesh. When the source IP belongs to a private/CGNAT range the proxy +// asks management to resolve it to a peer/user and to gate by the service's +// distribution_groups. On success the proxy installs the freshly minted JWT +// as a session cookie, sets UserID + Method=oidc on the captured data, and +// forwards directly โ€” operators see the same access-log shape as if the user +// had completed an OIDC redirect. Any failure (private-range mismatch, +// management unreachable, peer unknown, user not in group) returns false so +// the caller falls back to the existing OIDC scheme dispatch. +// +// Phase 3 adds a local-first short-circuit: when the request arrived on a +// per-account inbound listener the context carries a peerstore lookup +// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's +// roster the proxy denies fast without calling management. If the lookup +// confirms a known peer the RPC still runs for the user-identity tail +// (UserID + group access), but its result is cached for tunnelCacheTTL so +// repeat requests skip management entirely. +func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool { + if mw.sessionValidator == nil { + return false + } + clientIP := mw.resolveClientIP(r) + if !clientIP.IsValid() { + return false + } + if !isTunnelSourceIP(clientIP) { + return false + } + + if lookup := TunnelLookupFromContext(r.Context()); lookup != nil { + if _, ok := lookup(clientIP); !ok { + mw.logger.WithFields(log.Fields{ + "host": host, + "remote": clientIP, + }).Debug("local peerstore: tunnel IP not in account roster; denying without RPC") + return false + } + } + + resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{ + accountID: config.AccountID, + tunnelIP: clientIP, + domain: host, + }, mw.validateTunnelPeer) + if err != nil { + mw.logger.WithError(err).Debug("ValidateTunnelPeer failed; falling back to OIDC") + return false + } + if !resp.GetValid() || resp.GetSessionToken() == "" { + return false + } + + setSessionCookie(w, resp.GetSessionToken(), config.SessionExpiration) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(resp.GetUserId()) + cd.SetUserEmail(resp.GetUserEmail()) + cd.SetUserGroups(resp.GetPeerGroupIds()) + cd.SetUserGroupNames(resp.GetPeerGroupNames()) + cd.SetAuthMethod(auth.MethodOIDC.String()) + } + next.ServeHTTP(w, r) + return true +} + +// validateTunnelPeer adapts the SessionValidator interface to the cache's +// validateTunnelPeerFn signature. +func (mw *Middleware) validateTunnelPeer(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + return mw.sessionValidator.ValidateTunnelPeer(ctx, req) +} + +// cgnatPrefix covers RFC 6598 100.64.0.0/10, the CGNAT block NetBird +// allocates tunnel addresses from by default. IsPrivate() doesn't include +// it, so we check it explicitly. +var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10") + +// isTunnelSourceIP reports whether ip falls within an address range typical +// of NetBird tunnels: RFC1918 private space, IPv6 ULA, or CGNAT 100.64/10 +// (NetBird's default range). Loopback and link-local are excluded โ€” the +// fast-path is meant for peer-to-peer mesh traffic, not localhost. +func isTunnelSourceIP(ip netip.Addr) bool { + if !ip.IsValid() || ip.IsLoopback() || ip.IsLinkLocalUnicast() { + return false + } + if ip.IsPrivate() { + return true + } + return cgnatPrefix.Contains(ip) +} + // forwardWithHeaderAuth checks for a Header auth scheme. If the header validates, // the request is forwarded directly (no redirect), which is important for API clients. func (mw *Middleware) forwardWithHeaderAuth(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool { @@ -286,7 +454,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, auth.MethodHeader) if err != nil { - setHeaderCapturedData(r.Context(), "") + setHeaderCapturedData(r.Context(), "", "", nil, nil) status := http.StatusBadRequest msg := "invalid session token" if errors.Is(err, errValidationUnavailable) { @@ -298,7 +466,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho } if !result.Valid { - setHeaderCapturedData(r.Context(), result.UserID) + setHeaderCapturedData(r.Context(), result.UserID, result.UserEmail, result.Groups, result.GroupNames) http.Error(w, "Unauthorized", http.StatusUnauthorized) return true } @@ -306,6 +474,9 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho setSessionCookie(w, token, config.SessionExpiration) if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetUserID(result.UserID) + cd.SetUserEmail(result.UserEmail) + cd.SetUserGroups(result.Groups) + cd.SetUserGroupNames(result.GroupNames) cd.SetAuthMethod(auth.MethodHeader.String()) } @@ -315,7 +486,7 @@ func (mw *Middleware) tryHeaderScheme(w http.ResponseWriter, r *http.Request, ho func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Request, err error) bool { if errors.Is(err, ErrHeaderAuthFailed) { - setHeaderCapturedData(r.Context(), "") + setHeaderCapturedData(r.Context(), "", "", nil, nil) http.Error(w, "Unauthorized", http.StatusUnauthorized) return true } @@ -327,7 +498,7 @@ func (mw *Middleware) handleHeaderAuthError(w http.ResponseWriter, r *http.Reque return true } -func setHeaderCapturedData(ctx context.Context, userID string) { +func setHeaderCapturedData(ctx context.Context, userID, userEmail string, groups, groupNames []string) { cd := proxy.CapturedDataFromContext(ctx) if cd == nil { return @@ -335,6 +506,9 @@ func setHeaderCapturedData(ctx context.Context, userID string) { cd.SetOrigin(proxy.OriginAuth) cd.SetAuthMethod(auth.MethodHeader.String()) cd.SetUserID(userID) + cd.SetUserEmail(userEmail) + cd.SetUserGroups(groups) + cd.SetUserGroupNames(groupNames) } // authenticateWithSchemes tries each configured auth scheme in order. @@ -405,6 +579,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) cd.SetUserID(result.UserID) + cd.SetUserEmail(result.UserEmail) + cd.SetUserGroups(result.Groups) + cd.SetUserGroupNames(result.GroupNames) cd.SetAuthMethod(scheme.Type().String()) requestID = cd.GetRequestID() } @@ -419,6 +596,9 @@ func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Re if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) cd.SetUserID(result.UserID) + cd.SetUserEmail(result.UserEmail) + cd.SetUserGroups(result.Groups) + cd.SetUserGroupNames(result.GroupNames) cd.SetAuthMethod(scheme.Type().String()) } redirectURL := stripSessionTokenParam(r.URL) @@ -454,12 +634,9 @@ func wasCredentialSubmitted(r *http.Request, method auth.Method) bool { return false } -// AddDomain registers authentication schemes for the given domain. -// If schemes are provided, a valid session public key is required to sign/verify -// session JWTs. Returns an error if the key is missing or invalid. -// Callers must not serve the domain if this returns an error, to avoid -// exposing an unauthenticated service. -func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter) error { +// AddDomain registers authentication schemes for the given domain. With schemes a valid session public key is required. +// private=true forces ValidateTunnelPeer enforcement (403 on failure) regardless of the schemes list. +func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID types.AccountID, serviceID types.ServiceID, ipRestrictions *restrict.Filter, private bool) error { if len(schemes) == 0 { mw.domainsMux.Lock() defer mw.domainsMux.Unlock() @@ -467,6 +644,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st AccountID: accountID, ServiceID: serviceID, IPRestrictions: ipRestrictions, + Private: private, } return nil } @@ -488,6 +666,7 @@ func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 st AccountID: accountID, ServiceID: serviceID, IPRestrictions: ipRestrictions, + Private: private, } return nil } @@ -518,18 +697,25 @@ func (mw *Middleware) validateSessionToken(ctx context.Context, host, token stri }).Debug("Session validation denied") return &validationResult{ UserID: resp.UserId, + UserEmail: resp.GetUserEmail(), Valid: false, DeniedReason: resp.DeniedReason, }, nil } - return &validationResult{UserID: resp.UserId, Valid: true}, nil + return &validationResult{ + UserID: resp.UserId, + UserEmail: resp.GetUserEmail(), + Valid: true, + Groups: resp.GetPeerGroupIds(), + GroupNames: resp.GetPeerGroupNames(), + }, nil } - userID, _, err := auth.ValidateSessionJWT(token, host, publicKey) + userID, email, _, groups, groupNames, err := auth.ValidateSessionJWT(token, host, publicKey) if err != nil { return nil, err } - return &validationResult{UserID: userID, Valid: true}, nil + return &validationResult{UserID: userID, UserEmail: email, Valid: true, Groups: groups, GroupNames: groupNames}, nil } // stripSessionTokenParam returns the request URI with the session_token query diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 2c93d7912..84c319446 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "crypto/tls" "encoding/base64" "errors" "net/http" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/proxy/auth" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/restrict" + "github.com/netbirdio/netbird/proxy/internal/types" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -62,7 +64,7 @@ func TestAddDomain_ValidKey(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil) + err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false) require.NoError(t, err) mw.domainsMux.RLock() @@ -79,7 +81,7 @@ func TestAddDomain_EmptyKey(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil) + err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "", nil, false) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -93,7 +95,7 @@ func TestAddDomain_InvalidBase64(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil) + err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "", nil, false) require.Error(t, err) assert.Contains(t, err.Error(), "decode session public key") @@ -108,7 +110,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) { shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort")) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil) + err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "", nil, false) require.Error(t, err) assert.Contains(t, err.Error(), "invalid session public key size") @@ -121,7 +123,7 @@ func TestAddDomain_WrongKeySize(t *testing.T) { func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) - err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil) + err := mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false) require.NoError(t, err, "domains with no auth schemes should not require a key") mw.domainsMux.RLock() @@ -137,8 +139,8 @@ func TestAddDomain_OverwritesPreviousConfig(t *testing.T) { scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", "", nil, false)) mw.domainsMux.RLock() config := mw.domains["example.com"] @@ -154,7 +156,7 @@ func TestRemoveDomain(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) mw.RemoveDomain("example.com") @@ -178,7 +180,7 @@ func TestProtect_UnknownDomainPassesThrough(t *testing.T) { func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) - require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", "", nil, false)) handler := mw.Protect(newPassthroughHandler()) @@ -195,7 +197,7 @@ func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -216,7 +218,7 @@ func TestProtect_HostWithPortIsMatched(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -237,9 +239,9 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) - token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour) require.NoError(t, err) capturedData := proxy.NewCapturedData("") @@ -262,15 +264,48 @@ func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) { assert.Equal(t, "authenticated", rec.Body.String()) } +// TestProtect_SessionCookieGroupsPropagate verifies the cookie path lifts the +// JWT's groups claim into CapturedData so policy-aware middlewares can +// authorise without an extra management round-trip. +func TestProtect_SessionCookieGroupsPropagate(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) + + groups := []string{"engineering", "sre"} + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, groups, nil, time.Hour) + require.NoError(t, err) + + capturedData := proxy.NewCapturedData("") + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cd := proxy.CapturedDataFromContext(r.Context()) + require.NotNil(t, cd, "captured data must be present in request context") + assert.Equal(t, "test-user", cd.GetUserID()) + assert.Equal(t, groups, cd.GetUserGroups(), "JWT groups claim must propagate to CapturedData") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData)) + req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token}) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, "request with valid groups-bearing cookie must succeed") + assert.Equal(t, groups, capturedData.GetUserGroups(), "CapturedData groups must be retained after handler completes") +} + func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) // Sign a token that expired 1 second ago. - token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second) + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, -time.Second) require.NoError(t, err) var backendCalled bool @@ -293,10 +328,10 @@ func TestProtect_WrongDomainCookieIsRejected(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) // Token signed for a different domain audience. - token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "", "other.com", auth.MethodPIN, nil, nil, time.Hour) require.NoError(t, err) var backendCalled bool @@ -320,10 +355,10 @@ func TestProtect_WrongKeyCookieIsRejected(t *testing.T) { kp2 := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", "", nil, false)) // Token signed with a different private key. - token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour) + token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour) require.NoError(t, err) var backendCalled bool @@ -345,7 +380,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) - token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "", "example.com", auth.MethodPIN, nil, nil, time.Hour) require.NoError(t, err) scheme := &stubScheme{ @@ -357,7 +392,7 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -410,7 +445,7 @@ func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) handler := mw.Protect(newPassthroughHandler()) @@ -427,7 +462,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) - token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "", "example.com", auth.MethodPassword, nil, nil, time.Hour) require.NoError(t, err) // First scheme (PIN) always fails, second scheme (password) succeeds. @@ -446,7 +481,7 @@ func TestProtect_MultipleSchemes(t *testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", "", nil, false)) var backendCalled bool backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -476,7 +511,7 @@ func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) { return "invalid-jwt-token", "", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) handler := mw.Protect(newPassthroughHandler()) @@ -500,7 +535,7 @@ func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) { key := base64.StdEncoding.EncodeToString(randomBytes) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil) + err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "", nil, false) require.NoError(t, err, "any 32-byte key should be accepted at registration time") } @@ -509,10 +544,10 @@ func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) { kp := generateTestKeyPair(t) scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) // Attempt to overwrite with an invalid key. - err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil) + err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "", nil, false) require.Error(t, err) // The original valid config should still be intact. @@ -536,7 +571,7 @@ func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) @@ -563,7 +598,7 @@ func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) { return "", "password", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) @@ -590,7 +625,7 @@ func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) @@ -678,7 +713,7 @@ func TestCheckIPRestrictions_UnparseableAddress(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}})) + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"10.0.0.0/8"}}), false) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -714,7 +749,7 @@ func TestCheckIPRestrictions_UsesCapturedDataClientIP(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}})) + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"203.0.113.0/24"}}), false) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -755,7 +790,7 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", - restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}})) + restrict.ParseFilter(restrict.FilterConfig{AllowedCountries: []string{"US"}}), false) require.NoError(t, err) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -770,6 +805,69 @@ func TestCheckIPRestrictions_NilGeoWithCountryRules(t *testing.T) { assert.Equal(t, http.StatusForbidden, rr.Code, "country restrictions with nil geo must deny") } +// TestCheckIPRestrictions_OverlayOriginSkipsCountryRules covers the +// inbound (WG) listener path: requests stamped with WithOverlayOrigin +// must skip country lookups, even when no geo database is configured. +// Without this short-circuit the inbound flow would fail-closed for +// every overlay request whenever country rules are configured. +func TestCheckIPRestrictions_OverlayOriginSkipsCountryRules(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(restrict.FilterConfig{ + AllowedCIDRs: []string{"100.64.0.0/10"}, + AllowedCountries: []string{"US"}, + }), false) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "100.64.5.6:5000" + req.Host = "example.com" + req = req.WithContext(types.WithOverlayOrigin(req.Context())) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, + "overlay-origin requests must not be denied by country rules they would fail without geo data") + + // Sanity check: the same filter without the overlay flag denies (no geo, + // country allowlist active โ†’ DenyGeoUnavailable). + req2 := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req2.RemoteAddr = "100.64.5.6:5000" + req2.Host = "example.com" + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + assert.Equal(t, http.StatusForbidden, rr2.Code, + "WAN-origin requests must still hit the full Check path and be denied without geo data") +} + +// TestCheckIPRestrictions_OverlayOriginRespectsCIDR confirms CIDR +// rules still apply on the overlay path so operators retain a way to +// scope private services to specific peer subnets. +func TestCheckIPRestrictions_OverlayOriginRespectsCIDR(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + + err := mw.AddDomain("example.com", nil, "", 0, "acc1", "svc1", + restrict.ParseFilter(restrict.FilterConfig{AllowedCIDRs: []string{"100.64.0.0/16"}}), false) + require.NoError(t, err) + + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req.RemoteAddr = "100.65.5.6:5000" // outside 100.64.0.0/16 + req.Host = "example.com" + req = req.WithContext(types.WithOverlayOrigin(req.Context())) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code, + "CIDR rules must still apply on the overlay path") +} + func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) @@ -781,11 +879,12 @@ func TestProtect_OIDCOnlyRedirectsDirectly(t *testing.T) { return "", oidcURL, nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) handler := mw.Protect(newPassthroughHandler()) - req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil) + req.TLS = &tls.ConnectionState{} rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) @@ -809,11 +908,12 @@ func TestProtect_OIDCWithOtherMethodShowsLoginPage(t *testing.T) { return "", "pin", nil }, } - require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{oidcScheme, pinScheme}, kp.PublicKey, time.Hour, "", "", nil, false)) handler := mw.Protect(newPassthroughHandler()) - req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil) + req.TLS = &tls.ConnectionState{} rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) @@ -834,7 +934,7 @@ func (m *mockAuthenticator) Authenticate(ctx context.Context, in *proto.Authenti // returns a signed session token when the expected header value is provided. func newHeaderSchemeWithToken(t *testing.T, kp *sessionkey.KeyPair, headerName, expectedValue string) Header { t.Helper() - token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour) require.NoError(t, err) mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { @@ -852,7 +952,7 @@ func TestProtect_HeaderAuth_ForwardsOnSuccess(t *testing.T) { kp := generateTestKeyPair(t) hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) var backendCalled bool capturedData := proxy.NewCapturedData("") @@ -895,7 +995,7 @@ func TestProtect_HeaderAuth_MissingHeaderFallsThrough(t *testing.T) { hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") // Also add a PIN scheme so we can verify fallthrough behavior. pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr, pinScheme}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) handler := mw.Protect(newPassthroughHandler()) @@ -915,7 +1015,7 @@ func TestProtect_HeaderAuth_WrongValueReturns401(t *testing.T) { return &proto.AuthenticateResponse{Success: false}, nil }} hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) capturedData := proxy.NewCapturedData("") handler := mw.Protect(newPassthroughHandler()) @@ -938,7 +1038,7 @@ func TestProtect_HeaderAuth_InfraErrorReturns502(t *testing.T) { return nil, errors.New("gRPC unavailable") }} hdr := NewHeader(mock, "svc1", "acc1", "X-API-Key") - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) handler := mw.Protect(newPassthroughHandler()) @@ -955,7 +1055,7 @@ func TestProtect_HeaderAuth_SubsequentRequestUsesSessionCookie(t *testing.T) { kp := generateTestKeyPair(t) hdr := newHeaderSchemeWithToken(t, kp, "X-API-Key", "secret-key") - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -1006,7 +1106,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) { mock := &mockAuthenticator{fn: func(_ context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { ha := req.GetHeaderAuth() if ha != nil && accepted[ha.GetHeaderValue()] { - token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "example.com", auth.MethodHeader, time.Hour) + token, err := sessionkey.SignToken(kp.PrivateKey, "header-user", "", "example.com", auth.MethodHeader, nil, nil, time.Hour) require.NoError(t, err) return &proto.AuthenticateResponse{Success: true, SessionToken: token}, nil } @@ -1015,7 +1115,7 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) { // Single Header scheme (as if one entry existed), but the mock checks both values. hdr := NewHeader(mock, "svc1", "acc1", "Authorization") - require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil)) + require.NoError(t, mw.AddDomain("example.com", []Scheme{hdr}, kp.PublicKey, time.Hour, "acc1", "svc1", nil, false)) var backendCalled bool handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -1059,3 +1159,71 @@ func TestProtect_HeaderAuth_MultipleValuesSameHeader(t *testing.T) { assert.False(t, backendCalled, "unknown token should be rejected") }) } + +// TestProtect_OIDCOnPlainHTTP_BlockedWith400 verifies that when an OIDC +// scheme is configured and the request arrived without TLS, the middleware +// short-circuits with a 400 instead of dispatching to the IdP redirect. +func TestProtect_OIDCOnPlainHTTP_BlockedWith400(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", "https://idp.example.com/authorize", nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code, "OIDC over plain HTTP should be rejected") + assert.Contains(t, rec.Body.String(), "OIDC requires TLS", "response body should explain the rejection") +} + +// TestProtect_OIDCOverTLS_NotBlocked confirms the same configuration works +// over TLS โ€” the block only fires on plain HTTP. +func TestProtect_OIDCOverTLS_NotBlocked(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{ + method: auth.MethodOIDC, + authFn: func(_ *http.Request) (string, string, error) { + return "", "https://idp.example.com/authorize", nil + }, + } + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "https://example.com/", nil) + req.TLS = &tls.ConnectionState{} + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code, "OIDC over TLS should redirect to IdP") +} + +// TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked confirms that the OIDC +// block only fires when an OIDC scheme is configured. PIN-only domains +// pass through normally on plain HTTP. +func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) { + mw := NewMiddleware(log.StandardLogger(), nil, nil) + kp := generateTestKeyPair(t) + + scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"} + require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false)) + + handler := mw.Protect(newPassthroughHandler()) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP") +} diff --git a/proxy/internal/auth/tunnel_cache.go b/proxy/internal/auth/tunnel_cache.go new file mode 100644 index 000000000..10b671d82 --- /dev/null +++ b/proxy/internal/auth/tunnel_cache.go @@ -0,0 +1,171 @@ +package auth + +import ( + "context" + "net/netip" + "sync" + "time" + + "golang.org/x/sync/singleflight" + + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// tunnelCacheTTL caps how long a positive ValidateTunnelPeer result is +// reused before re-fetching from management. 5 minutes balances freshness +// against management load on busy mesh networks. +const tunnelCacheTTL = 300 * time.Second + +// tunnelCachePerAccount caps the number of cached identities per account. +// Bounded eviction avoids memory growth in pathological cases (huge peer +// roster, brief request bursts) while staying generous for normal use. +const tunnelCachePerAccount = 1024 + +// tunnelCacheKey identifies a cached entry by tunnel IP and originating +// account. Domain is part of the value, not the key, because the +// management response is per (account, IP) โ€” domain only gates whether a +// re-fetch is needed if the operator is accessing a different service. +type tunnelCacheKey struct { + accountID types.AccountID + tunnelIP netip.Addr + domain string +} + +// tunnelCacheEntry stores a positive validation response with the time it +// was minted. Entries past tunnelCacheTTL are treated as misses. +type tunnelCacheEntry struct { + resp *proto.ValidateTunnelPeerResponse + cachedAt time.Time +} + +// tunnelValidationCache memoizes ValidateTunnelPeer responses keyed by +// (accountID, tunnelIP, domain). Only successful, valid responses are +// cached โ€” denials skip the cache so policy changes apply immediately. +// Single-flight de-duplicates concurrent fetches for the same key so a +// burst of cold requests collapses into a single RPC. +type tunnelValidationCache struct { + mu sync.Mutex + entries map[types.AccountID]*accountBucket + flight singleflight.Group + ttl time.Duration + maxSize int + now func() time.Time +} + +// accountBucket holds the cached entries for a single account, with a +// FIFO eviction queue used when the bucket exceeds maxSize. +type accountBucket struct { + items map[tunnelCacheKey]tunnelCacheEntry + order []tunnelCacheKey +} + +// newTunnelValidationCache constructs a cache with default TTL and bounds. +func newTunnelValidationCache() *tunnelValidationCache { + return &tunnelValidationCache{ + entries: make(map[types.AccountID]*accountBucket), + ttl: tunnelCacheTTL, + maxSize: tunnelCachePerAccount, + now: time.Now, + } +} + +// get returns a cached response for the key, or nil when missing or +// expired. Expired entries are evicted lazily on read. +func (c *tunnelValidationCache) get(key tunnelCacheKey) *proto.ValidateTunnelPeerResponse { + c.mu.Lock() + defer c.mu.Unlock() + + bucket, ok := c.entries[key.accountID] + if !ok { + return nil + } + entry, ok := bucket.items[key] + if !ok { + return nil + } + if c.now().Sub(entry.cachedAt) > c.ttl { + delete(bucket.items, key) + bucket.order = removeKey(bucket.order, key) + return nil + } + return entry.resp +} + +// put records a positive response under the key. Evicts the oldest entry +// in the account's bucket when the bound is exceeded. +func (c *tunnelValidationCache) put(key tunnelCacheKey, resp *proto.ValidateTunnelPeerResponse) { + c.mu.Lock() + defer c.mu.Unlock() + + bucket, ok := c.entries[key.accountID] + if !ok { + bucket = &accountBucket{items: make(map[tunnelCacheKey]tunnelCacheEntry)} + c.entries[key.accountID] = bucket + } + if _, exists := bucket.items[key]; !exists { + bucket.order = append(bucket.order, key) + } + bucket.items[key] = tunnelCacheEntry{resp: resp, cachedAt: c.now()} + + for len(bucket.order) > c.maxSize { + oldest := bucket.order[0] + bucket.order = bucket.order[1:] + delete(bucket.items, oldest) + } +} + +// removeKey drops the first occurrence of needle from order. The cache +// uses small slices so a linear scan is cheaper than a map+slice combo. +func removeKey(order []tunnelCacheKey, needle tunnelCacheKey) []tunnelCacheKey { + for i, k := range order { + if k == needle { + return append(order[:i], order[i+1:]...) + } + } + return order +} + +// flightKey turns a cache key into a single-flight string. AccountID and +// IP isolation by themselves are insufficient because different domains +// for the same peer/account may have different group access. +func flightKey(key tunnelCacheKey) string { + return string(key.accountID) + "|" + key.tunnelIP.String() + "|" + key.domain +} + +// validateTunnelPeerFn is the RPC entry point the cache wraps. It matches +// the SessionValidator.ValidateTunnelPeer signature without exposing the +// gRPC option variadic, since callers don't need it on the cache hot path. +type validateTunnelPeerFn func(ctx context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) + +// fetch returns a cached response when present, otherwise calls validate +// under single-flight and caches the result. Denied responses pass +// through but are not cached so policy changes apply immediately. +func (c *tunnelValidationCache) fetch(ctx context.Context, key tunnelCacheKey, validate validateTunnelPeerFn) (*proto.ValidateTunnelPeerResponse, bool, error) { + if resp := c.get(key); resp != nil { + return resp, true, nil + } + + flight := flightKey(key) + res, err, _ := c.flight.Do(flight, func() (any, error) { + if cached := c.get(key); cached != nil { + return cached, nil + } + resp, err := validate(ctx, &proto.ValidateTunnelPeerRequest{ + TunnelIp: key.tunnelIP.String(), + Domain: key.domain, + }) + if err != nil { + return nil, err + } + if resp.GetValid() && resp.GetSessionToken() != "" { + c.put(key, resp) + } + return resp, nil + }) + if err != nil { + return nil, false, err + } + resp, _ := res.(*proto.ValidateTunnelPeerResponse) + return resp, false, nil +} diff --git a/proxy/internal/auth/tunnel_cache_test.go b/proxy/internal/auth/tunnel_cache_test.go new file mode 100644 index 000000000..1a63dc107 --- /dev/null +++ b/proxy/internal/auth/tunnel_cache_test.go @@ -0,0 +1,171 @@ +package auth + +import ( + "context" + "net/netip" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func newTestKey(account types.AccountID, ip string, domain string) tunnelCacheKey { + return tunnelCacheKey{ + accountID: account, + tunnelIP: netip.MustParseAddr(ip), + domain: domain, + } +} + +func TestTunnelCache_HitSkipsRPC(t *testing.T) { + cache := newTunnelValidationCache() + key := newTestKey("acct-1", "100.64.0.10", "svc.example") + + var calls int32 + validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&calls, 1) + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}, nil + } + + resp, fromCache, err := cache.fetch(context.Background(), key, validate) + require.NoError(t, err) + require.NotNil(t, resp, "first fetch returns RPC response") + assert.False(t, fromCache, "first fetch must not be cached") + + resp2, fromCache2, err := cache.fetch(context.Background(), key, validate) + require.NoError(t, err) + require.NotNil(t, resp2, "second fetch returns cached response") + assert.True(t, fromCache2, "second fetch must be served from cache") + assert.Equal(t, "user-1", resp2.GetUserId(), "cached response should preserve user identity") + assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "validate should run exactly once with one cache hit") +} + +func TestTunnelCache_ExpiredEntryRefetches(t *testing.T) { + cache := newTunnelValidationCache() + clock := time.Now() + cache.now = func() time.Time { return clock } + + key := newTestKey("acct-1", "100.64.0.10", "svc.example") + var calls int32 + validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&calls, 1) + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil + } + + _, _, err := cache.fetch(context.Background(), key, validate) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "first fetch issues one RPC") + + clock = clock.Add(tunnelCacheTTL + time.Second) + + _, fromCache, err := cache.fetch(context.Background(), key, validate) + require.NoError(t, err) + assert.False(t, fromCache, "expired entry must miss the cache") + assert.Equal(t, int32(2), atomic.LoadInt32(&calls), "expired entry forces a re-fetch") +} + +func TestTunnelCache_DeniedResponseNotCached(t *testing.T) { + cache := newTunnelValidationCache() + key := newTestKey("acct-1", "100.64.0.10", "svc.example") + + var calls int32 + validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&calls, 1) + return &proto.ValidateTunnelPeerResponse{Valid: false, DeniedReason: "not_in_group"}, nil + } + + for i := 0; i < 3; i++ { + _, _, err := cache.fetch(context.Background(), key, validate) + require.NoError(t, err, "fetch must not error on denied response") + } + assert.Equal(t, int32(3), atomic.LoadInt32(&calls), "denied responses bypass the cache so policy changes apply immediately") +} + +func TestTunnelCache_ConcurrentColdHitsCoalesce(t *testing.T) { + cache := newTunnelValidationCache() + key := newTestKey("acct-1", "100.64.0.10", "svc.example") + + gate := make(chan struct{}) + var calls int32 + validate := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&calls, 1) + <-gate + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"}, nil + } + + const workers = 16 + var wg sync.WaitGroup + wg.Add(workers) + results := make([]bool, workers) + for i := 0; i < workers; i++ { + go func(idx int) { + defer wg.Done() + resp, _, err := cache.fetch(context.Background(), key, validate) + results[idx] = err == nil && resp.GetValid() + }(i) + } + + time.Sleep(20 * time.Millisecond) + close(gate) + wg.Wait() + + for i, ok := range results { + assert.Truef(t, ok, "worker %d should observe a successful response", i) + } + assert.Equal(t, int32(1), atomic.LoadInt32(&calls), "single-flight must collapse concurrent cold fetches into one RPC") +} + +func TestTunnelCache_PerAccountIsolation(t *testing.T) { + cache := newTunnelValidationCache() + keyA := newTestKey("acct-a", "100.64.0.10", "svc.example") + keyB := newTestKey("acct-b", "100.64.0.10", "svc.example") + + var callsA, callsB int32 + validateA := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&callsA, 1) + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-a", UserId: "user-a"}, nil + } + validateB := func(_ context.Context, _ *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + atomic.AddInt32(&callsB, 1) + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-b", UserId: "user-b"}, nil + } + + respA, _, err := cache.fetch(context.Background(), keyA, validateA) + require.NoError(t, err) + respB, _, err := cache.fetch(context.Background(), keyB, validateB) + require.NoError(t, err) + + assert.Equal(t, "user-a", respA.GetUserId(), "account A response should belong to user-a") + assert.Equal(t, "user-b", respB.GetUserId(), "account B response must not be served from account A's cache") + assert.Equal(t, int32(1), atomic.LoadInt32(&callsA), "validateA called exactly once") + assert.Equal(t, int32(1), atomic.LoadInt32(&callsB), "validateB called exactly once") +} + +func TestTunnelCache_BoundedSizeEvictsOldest(t *testing.T) { + cache := newTunnelValidationCache() + cache.maxSize = 2 + + validate := func(_ context.Context, req *proto.ValidateTunnelPeerRequest) (*proto.ValidateTunnelPeerResponse, error) { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok-" + req.GetTunnelIp()}, nil + } + + keys := []tunnelCacheKey{ + newTestKey("acct-1", "100.64.0.10", "svc"), + newTestKey("acct-1", "100.64.0.11", "svc"), + newTestKey("acct-1", "100.64.0.12", "svc"), + } + for _, k := range keys { + _, _, err := cache.fetch(context.Background(), k, validate) + require.NoError(t, err) + } + + assert.Nil(t, cache.get(keys[0]), "oldest key should be evicted past maxSize") + assert.NotNil(t, cache.get(keys[1]), "second-newest must remain cached") + assert.NotNil(t, cache.get(keys[2]), "newest must remain cached") +} diff --git a/proxy/internal/auth/tunnel_lookup_test.go b/proxy/internal/auth/tunnel_lookup_test.go new file mode 100644 index 000000000..cc8081af2 --- /dev/null +++ b/proxy/internal/auth/tunnel_lookup_test.go @@ -0,0 +1,325 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/netip" + "sync/atomic" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// stubSessionValidator records ValidateTunnelPeer calls and returns the +// pre-canned response. Counts let tests assert RPC traffic. +type stubSessionValidator struct { + respFn func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse + respErr error + tunnelCalls atomic.Int32 +} + +func (s *stubSessionValidator) ValidateSession(_ context.Context, _ *proto.ValidateSessionRequest, _ ...grpc.CallOption) (*proto.ValidateSessionResponse, error) { + return &proto.ValidateSessionResponse{Valid: false}, nil +} + +func (s *stubSessionValidator) ValidateTunnelPeer(_ context.Context, in *proto.ValidateTunnelPeerRequest, _ ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) { + s.tunnelCalls.Add(1) + if s.respErr != nil { + return nil, s.respErr + } + if s.respFn != nil { + return s.respFn(in), nil + } + return &proto.ValidateTunnelPeerResponse{Valid: false}, nil +} + +func newTunnelMiddleware(t *testing.T, validator SessionValidator) *Middleware { + t.Helper() + mw := NewMiddleware(log.New(), validator, nil) + require.NoError(t, mw.AddDomain("svc.example", nil, "", 0, "acct-1", "svc-1", nil, false)) + return mw +} + +func newTunnelRequest(remoteAddr string) (*httptest.ResponseRecorder, *http.Request) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://svc.example/", nil) + r.Host = "svc.example" + r.RemoteAddr = remoteAddr + return w, r +} + +// TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast verifies the +// short-circuit: a tunnel IP not in the account's roster never reaches +// management's ValidateTunnelPeer. +func TestForwardWithTunnelPeer_LocalLookupUnknownIPDeniesFast(t *testing.T) { + validator := &stubSessionValidator{} + mw := newTunnelMiddleware(t, validator) + + lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) { + return PeerIdentity{}, false + }) + + w, r := newTunnelRequest("100.64.0.99:55555") + r = r.WithContext(WithTunnelLookup(r.Context(), lookup)) + + called := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }) + + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next) + + assert.False(t, handled, "unknown peer must fall through, not forward") + assert.False(t, called, "next handler must not run for unknown peer") + assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must be skipped on local-lookup miss") +} + +// TestForwardWithTunnelPeer_GroupsPropagateToCapturedData verifies the proxy +// surfaces the calling peer's group memberships from ValidateTunnelPeerResponse +// onto CapturedData so policy-aware middlewares can authorise without an +// extra management round-trip. +func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) { + groups := []string{"engineering", "sre"} + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{ + Valid: true, + SessionToken: "tok", + UserId: "user-1", + PeerGroupIds: groups, + } + }, + } + mw := newTunnelMiddleware(t, validator) + + w, r := newTunnelRequest("100.64.0.10:55555") + cd := proxy.NewCapturedData("") + r = r.WithContext(proxy.WithCapturedData(r.Context(), cd)) + + called := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }) + + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next) + + require.True(t, handled, "valid tunnel-peer response must forward") + require.True(t, called, "next handler must run") + assert.Equal(t, "user-1", cd.GetUserID(), "user id must propagate from tunnel-peer response") + assert.Equal(t, groups, cd.GetUserGroups(), "peer group IDs must propagate from tunnel-peer response") +} + +// TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs verifies that a +// known tunnel IP still triggers ValidateTunnelPeer for the user-identity +// tail (UserID + group access). Phase 3 only short-circuits the deny path. +func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"} + }, + } + mw := newTunnelMiddleware(t, validator) + + knownIP := netip.MustParseAddr("100.64.0.10") + lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) { + if ip == knownIP { + return PeerIdentity{PubKey: "pk", TunnelIP: ip, FQDN: "peer.netbird.cloud"}, true + } + return PeerIdentity{}, false + }) + + w, r := newTunnelRequest(knownIP.String() + ":55555") + r = r.WithContext(WithTunnelLookup(r.Context(), lookup)) + + called := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }) + + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next) + + assert.True(t, handled, "known peer with valid RPC response must forward") + assert.True(t, called, "next handler must run on success") + assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer") +} + +// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing +// behaviour stays intact on the host-level listener (no lookup attached). +func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"} + }, + } + mw := newTunnelMiddleware(t, validator) + + w, r := newTunnelRequest("100.64.0.10:55555") + called := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true }) + + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next) + + assert.True(t, handled, "host-level path forwards on positive RPC result") + assert.True(t, called, "next handler runs on host-level success") + assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)") +} + +// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC +// failure still falls through to the next scheme (no false positive). +func TestForwardWithTunnelPeer_RPCErrorFallsThrough(t *testing.T) { + validator := &stubSessionValidator{respErr: errors.New("management down")} + mw := newTunnelMiddleware(t, validator) + + knownIP := netip.MustParseAddr("100.64.0.10") + lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) { + return PeerIdentity{TunnelIP: ip}, true + }) + + w, r := newTunnelRequest(knownIP.String() + ":55555") + r = r.WithContext(WithTunnelLookup(r.Context(), lookup)) + + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + + assert.False(t, handled, "RPC error must let the caller try other schemes") + assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC was attempted exactly once") +} + +// TestForwardWithTunnelPeer_CacheReusesPositiveResponse confirms the +// (account, IP, domain) cache prevents repeated RPCs for the same peer. +func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"} + }, + } + mw := newTunnelMiddleware(t, validator) + + for i := 0; i < 4; i++ { + w, r := newTunnelRequest("100.64.0.10:55555") + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) + config, _ := mw.getDomainConfig("svc.example") + handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next) + require.True(t, handled, "iteration %d should forward", i) + } + + assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "subsequent forwards must hit the cache, not management") +} + +// TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey ensures cache keys +// honour account scoping โ€” same tunnel IP on different accounts must not +// collide. +func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(req *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user"} + }, + } + mw := NewMiddleware(log.New(), validator, nil) + + require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false)) + require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false)) + + for _, host := range []string{"svc-a.example", "svc-b.example"} { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil) + r.Host = host + r.RemoteAddr = "100.64.0.10:55555" + config, _ := mw.getDomainConfig(host) + handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + require.True(t, handled, "host %s should forward", host) + } + + assert.Equal(t, int32(2), validator.tunnelCalls.Load(), "cache must not collide across accounts even when tunnel IPs match") +} + +// TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache +// guarantees that the deny-fast path leaves the cache untouched, so a +// subsequent request from the same IP after the peerstore catches up +// goes through the normal RPC flow. +func TestForwardWithTunnelPeer_LocalLookupShortCircuitDoesNotPopulateCache(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok"} + }, + } + mw := newTunnelMiddleware(t, validator) + + knownIP := netip.MustParseAddr("100.64.0.10") + known := false + lookup := TunnelLookupFunc(func(ip netip.Addr) (PeerIdentity, bool) { + if known && ip == knownIP { + return PeerIdentity{TunnelIP: ip}, true + } + return PeerIdentity{}, false + }) + + doRequest := func() bool { + w, r := newTunnelRequest(knownIP.String() + ":55555") + r = r.WithContext(WithTunnelLookup(r.Context(), lookup)) + config, _ := mw.getDomainConfig("svc.example") + return mw.forwardWithTunnelPeer(w, r, "svc.example", config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + } + + require.False(t, doRequest(), "first request must short-circuit") + require.Equal(t, int32(0), validator.tunnelCalls.Load(), "short-circuit must not populate the cache") + + known = true + require.True(t, doRequest(), "second request with peer in roster must forward via RPC") + assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC runs once after peerstore catches up") +} + +func TestPrivateService_FailsClosedOnTunnelPeerFailure(t *testing.T) { + mw := NewMiddleware(log.New(), nil, nil) + require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true)) + + called := false + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil) + req.Host = "private.svc" + req.RemoteAddr = "100.64.0.10:55555" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.False(t, called) +} + +func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) { + validator := &stubSessionValidator{ + respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse { + return &proto.ValidateTunnelPeerResponse{ + Valid: true, + SessionToken: "tok", + UserId: "user-1", + } + }, + } + mw := NewMiddleware(log.New(), validator, nil) + require.NoError(t, mw.AddDomain("private.svc", nil, "", 0, "acct-1", "svc-1", nil, true)) + + called := false + handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil) + req.Host = "private.svc" + req.RemoteAddr = "100.64.0.10:55555" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, called) +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 09c25afb2..736781652 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -11,7 +11,6 @@ import ( "net/url" "strings" "time" - ) // StatusFilters contains filter options for status queries. @@ -160,6 +159,49 @@ func (c *Client) printClients(data map[string]any) { for _, item := range clients { c.printClientRow(item) } + + c.printInboundListeners(clients) +} + +func (c *Client) printInboundListeners(clients []any) { + type row struct { + accountID string + tunnelIP string + httpsPort int + httpPort int + } + var rows []row + for _, item := range clients { + client, ok := item.(map[string]any) + if !ok { + continue + } + inbound, ok := client["inbound_listener"].(map[string]any) + if !ok { + continue + } + tunnelIP, _ := inbound["tunnel_ip"].(string) + httpsPort, _ := inbound["https_port"].(float64) + httpPort, _ := inbound["http_port"].(float64) + accountID, _ := client["account_id"].(string) + rows = append(rows, row{ + accountID: accountID, + tunnelIP: tunnelIP, + httpsPort: int(httpsPort), + httpPort: int(httpPort), + }) + } + if len(rows) == 0 { + return + } + + _, _ = fmt.Fprintln(c.out) + _, _ = fmt.Fprintln(c.out, "Inbound listeners (per-account):") + _, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7s %s\n", "ACCOUNT ID", "TUNNEL IP", "HTTPS", "HTTP") + _, _ = fmt.Fprintln(c.out, " "+strings.Repeat("-", 78)) + for _, r := range rows { + _, _ = fmt.Fprintf(c.out, " %-38s %-20s %-7d %d\n", r.accountID, r.tunnelIP, r.httpsPort, r.httpPort) + } } func (c *Client) printClientRow(item any) { @@ -219,7 +261,14 @@ func (c *Client) ClientStatus(ctx context.Context, accountID string, filters Sta } func (c *Client) printClientStatus(data map[string]any) { - _, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"]) + _, _ = fmt.Fprintf(c.out, "Account: %v\n", data["account_id"]) + if inbound, ok := data["inbound_listener"].(map[string]any); ok { + tunnelIP, _ := inbound["tunnel_ip"].(string) + httpsPort, _ := inbound["https_port"].(float64) + httpPort, _ := inbound["http_port"].(float64) + _, _ = fmt.Fprintf(c.out, "Inbound listener: %s (https=%d, http=%d)\n", tunnelIP, int(httpsPort), int(httpPort)) + } + _, _ = fmt.Fprintln(c.out) if status, ok := data["status"].(string); ok { _, _ = fmt.Fprint(c.out, status) } diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 23ca4adbb..1dbfe1522 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -61,6 +61,23 @@ type clientProvider interface { ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo } +// InboundListenerInfo describes a per-account inbound listener as +// surfaced through the debug HTTP handler. Mirrors the proto sub-message +// emitted with SendStatusUpdate so dashboards and CLI tooling see the +// same shape. +type InboundListenerInfo struct { + TunnelIP string `json:"tunnel_ip"` + HTTPSPort uint16 `json:"https_port"` + HTTPPort uint16 `json:"http_port"` +} + +// InboundProvider exposes per-account inbound listener state. Optional; +// when nil the debug endpoint omits the inbound section entirely so the +// existing JSON shape stays additive. +type InboundProvider interface { + InboundListeners() map[types.AccountID]InboundListenerInfo +} + // healthChecker provides health probe state. type healthChecker interface { ReadinessProbe() bool @@ -80,6 +97,7 @@ type Handler struct { provider clientProvider health healthChecker certStatus certStatus + inbound InboundProvider logger *log.Logger startTime time.Time templates *template.Template @@ -108,6 +126,13 @@ func (h *Handler) SetCertStatus(cs certStatus) { h.certStatus = cs } +// SetInboundProvider wires per-account inbound listener observability. +// Pass nil (or skip the call) to keep the inbound section out of debug +// responses on proxies that don't run --private-inbound. +func (h *Handler) SetInboundProvider(p InboundProvider) { + h.inbound = p +} + func (h *Handler) loadTemplates() error { tmpl, err := template.ParseFS(templateFS, "templates/*.html") if err != nil { @@ -323,23 +348,35 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want sortedIDs := sortedAccountIDs(clients) if wantJSON { + var inboundAll map[types.AccountID]InboundListenerInfo + if h.inbound != nil { + inboundAll = h.inbound.InboundListeners() + } clientsJSON := make([]map[string]interface{}, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - clientsJSON = append(clientsJSON, map[string]interface{}{ + row := map[string]interface{}{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, "has_client": info.HasClient, "created_at": info.CreatedAt, "age": time.Since(info.CreatedAt).Round(time.Second).String(), - }) + } + if inb, ok := inboundAll[id]; ok { + row["inbound_listener"] = inb + } + clientsJSON = append(clientsJSON, row) } - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), "clients": clientsJSON, - }) + } + if len(inboundAll) > 0 { + resp["inbound_listener_count"] = len(inboundAll) + } + h.writeJSON(w, resp) return } @@ -421,10 +458,14 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc }) if wantJSON { - h.writeJSON(w, map[string]interface{}{ + resp := map[string]interface{}{ "account_id": accountID, "status": overview.FullDetailSummary(), - }) + } + if info, ok := h.inboundInfoFor(accountID); ok { + resp["inbound_listener"] = info + } + h.writeJSON(w, resp) return } @@ -437,6 +478,18 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc h.renderTemplate(w, "clientDetail", data) } +// inboundInfoFor returns the inbound listener info for an account, or +// ok=false when no inbound provider is wired or the account has no live +// listener. +func (h *Handler) inboundInfoFor(accountID types.AccountID) (InboundListenerInfo, bool) { + if h.inbound == nil { + return InboundListenerInfo{}, false + } + all := h.inbound.InboundListeners() + info, ok := all[accountID] + return info, ok +} + func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) { client, ok := h.provider.GetClient(accountID) if !ok { diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index a888ad9ed..e05ec78aa 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -52,8 +52,15 @@ type CapturedData struct { origin ResponseOrigin clientIP netip.Addr userID string - authMethod string - metadata map[string]string + userEmail string + userGroups []string + // userGroupNames pairs positionally with userGroups; populated from + // the JWT's group_names claim or from ValidateSession/Tunnel + // responses. Slice may be shorter than userGroups for tokens minted + // before names were resolvable. + userGroupNames []string + authMethod string + metadata map[string]string } // NewCapturedData creates a CapturedData with the given request ID. @@ -138,6 +145,81 @@ func (c *CapturedData) GetUserID() string { return c.userID } +// SetUserEmail records the authenticated user's email address. Used by +// policy-aware middlewares to stamp identity onto upstream requests +// (e.g. x-litellm-end-user-id) without a management round-trip. +func (c *CapturedData) SetUserEmail(email string) { + c.mu.Lock() + defer c.mu.Unlock() + c.userEmail = email +} + +// GetUserEmail returns the authenticated user's email address. Returns +// the empty string when the auth path didn't carry an email (e.g. +// non-OIDC schemes or legacy JWTs minted before the email claim). +func (c *CapturedData) GetUserEmail() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.userEmail +} + +// SetUserGroups records the authenticated user's group memberships so +// downstream policy-aware middlewares can authorise the request without +// an additional management round-trip. The auth middleware populates this +// from ValidateSessionResponse / ValidateTunnelPeerResponse and from the +// session JWT's groups claim on cookie-bearing requests. +func (c *CapturedData) SetUserGroups(groups []string) { + c.mu.Lock() + defer c.mu.Unlock() + if len(groups) == 0 { + c.userGroups = nil + return + } + c.userGroups = append(c.userGroups[:0], groups...) +} + +// GetUserGroups returns a copy of the authenticated user's group +// memberships. +func (c *CapturedData) GetUserGroups() []string { + c.mu.RLock() + defer c.mu.RUnlock() + if len(c.userGroups) == 0 { + return nil + } + out := make([]string, len(c.userGroups)) + copy(out, c.userGroups) + return out +} + +// SetUserGroupNames records the human-readable display names for the +// user's groups, ordered identically to UserGroups (positional +// pairing). Stamped onto upstream requests as X-NetBird-Groups so +// downstream services can read names rather than opaque ids. +func (c *CapturedData) SetUserGroupNames(names []string) { + c.mu.Lock() + defer c.mu.Unlock() + if len(names) == 0 { + c.userGroupNames = nil + return + } + c.userGroupNames = append(c.userGroupNames[:0], names...) +} + +// GetUserGroupNames returns a copy of the authenticated user's group +// display names. Position i pairs with UserGroups[i]. May be shorter +// than UserGroups for tokens minted before names were resolvable; the +// consumer should fall back to ids for missing positions. +func (c *CapturedData) GetUserGroupNames() []string { + c.mu.RLock() + defer c.mu.RUnlock() + if len(c.userGroupNames) == 0 { + return nil + } + out := make([]string, len(c.userGroupNames)) + copy(out, c.userGroupNames) + return out +} + // SetAuthMethod sets the authentication method used. func (c *CapturedData) SetAuthMethod(method string) { c.mu.Lock() diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index 246851d24..e437e78a7 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -86,6 +86,9 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if pt.RequestTimeout > 0 { ctx = types.WithDialTimeout(ctx, pt.RequestTimeout) } + if pt.DirectUpstream { + ctx = roundtrip.WithDirectUpstream(ctx) + } rewriteMatchedPath := result.matchedPath if pt.PathRewrite == PathRewritePreserve { @@ -142,6 +145,8 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHost r.Out.Header.Set(k, v) } + stampNetBirdIdentity(r) + clientIP := extractHostIP(r.In.RemoteAddr) if isTrustedAddr(clientIP, p.trustedProxies) { @@ -426,3 +431,70 @@ func opErrorContains(err error, substr string) bool { } return false } + +const ( + // headerNetBirdUser carries the authenticated user's display identity + // (email when the peer is attached to a user, else peer name) onto + // upstream requests. Stripped from inbound requests before stamping + // so a client can't spoof identity by setting the header themselves. + headerNetBirdUser = "X-NetBird-User" + // headerNetBirdGroups carries the user's group display names as a + // comma-separated list. Falls back to group IDs at positions where a + // name wasn't available at session-mint time. Labels containing a + // comma or any non-printable byte are dropped at stamp time so the + // list is unambiguously splittable by consumers. + headerNetBirdGroups = "X-NetBird-Groups" +) + +// isHeaderValueSafe reports whether v is a valid RFC 7230 field-value: +// VCHAR (0x21-0x7E), SP (0x20), or HTAB (0x09). Empty values are +// rejected; the caller decides whether to omit the header entirely. +func isHeaderValueSafe(v string) bool { + if v == "" { + return false + } + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\t' || (c >= 0x20 && c <= 0x7E) { + continue + } + return false + } + return true +} + +// stampNetBirdIdentity injects authenticated identity onto outbound +// requests as X-NetBird-User and X-NetBird-Groups. Always strips any +// client-sent values first (anti-spoof). Skips when the request didn't +// carry CapturedData (early-path errors, internal endpoints). +func stampNetBirdIdentity(r *httputil.ProxyRequest) { + r.Out.Header.Del(headerNetBirdUser) + r.Out.Header.Del(headerNetBirdGroups) + + cd := CapturedDataFromContext(r.In.Context()) + if cd == nil { + return + } + if email := cd.GetUserEmail(); isHeaderValueSafe(email) { + r.Out.Header.Set(headerNetBirdUser, email) + } + groupIDs := cd.GetUserGroups() + if len(groupIDs) == 0 { + return + } + groupNames := cd.GetUserGroupNames() + labels := make([]string, 0, len(groupIDs)) + for i, id := range groupIDs { + label := id + if i < len(groupNames) && groupNames[i] != "" { + label = groupNames[i] + } + if !isHeaderValueSafe(label) || strings.ContainsRune(label, ',') { + continue + } + labels = append(labels, label) + } + if len(labels) > 0 { + r.Out.Header.Set(headerNetBirdGroups, strings.Join(labels, ",")) + } +} diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index c53307837..d5158a6cc 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -1067,3 +1067,245 @@ func TestClassifyProxyError(t *testing.T) { }) } } + +func TestStampNetBirdIdentity_NoCapturedData_StripsOnly(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io") + pr.In.Header.Set(headerNetBirdGroups, "admin") + pr.Out.Header = pr.In.Header.Clone() + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser), + "client-supplied X-NetBird-User must be stripped when no captured identity is present") + assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups), + "client-supplied X-NetBird-Groups must be stripped when no captured identity is present") +} + +func TestStampNetBirdIdentity_StampsFromCapturedData(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + cd.SetUserEmail("alice@netbird.io") + cd.SetUserGroups([]string{"grp-eng", "grp-ops"}) + cd.SetUserGroupNames([]string{"engineering", "operations"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Equal(t, "alice@netbird.io", pr.Out.Header.Get(headerNetBirdUser), + "captured email must overwrite any spoofed value") + assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups), + "group display names must be CSV-joined in positional order") +} + +// TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty covers the +// tunnel-peer-without-user case (machine agents, unattached proxy peers). +// The proxy must still stamp the peer's groups so downstream services can +// authorise, but X-NetBird-User stays unset โ€” only its inbound stripping +// must happen. +func TestStampNetBirdIdentity_GroupsOnlyWhenEmailEmpty(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + cd.SetUserGroups([]string{"grp-machines"}) + cd.SetUserGroupNames([]string{"machines"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser), + "X-NetBird-User must remain unset when CapturedData carries no email") + assert.Equal(t, "machines", pr.Out.Header.Get(headerNetBirdGroups), + "groups must still be stamped for peers without a user identity") +} + +// TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty covers the symmetric +// case: identity-resolved user without resolved group memberships. +func TestStampNetBirdIdentity_EmailOnlyWhenGroupsEmpty(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + cd.SetUserEmail("carol@netbird.io") + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Equal(t, "carol@netbird.io", pr.Out.Header.Get(headerNetBirdUser), + "email must be stamped even when no groups are captured") + assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups), + "X-NetBird-Groups must remain unset when CapturedData carries no groups") +} + +func TestStampNetBirdIdentity_FallsBackToGroupIDsWhenNameMissing(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + + cd := NewCapturedData("req-1") + cd.SetUserEmail("bob@netbird.io") + cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"}) + // "grp-b" gets an explicit empty-string display name (not just a + // shorter slice). Both gap shapes must fall back to the id. + cd.SetUserGroupNames([]string{"alpha", "", ""}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Equal(t, "alpha,grp-b,grp-c", pr.Out.Header.Get(headerNetBirdGroups), + "empty-string and out-of-range name slots must both fall back to the group id") +} + +// TestStampNetBirdIdentity_DropsLabelsWithComma covers the +// comma-separator constraint: a group display name that itself contains +// a comma is dropped from the header (rather than corrupting the list), +// and the remaining labels are stamped. +func TestStampNetBirdIdentity_DropsLabelsWithComma(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + + cd := NewCapturedData("req-1") + cd.SetUserEmail("alice@netbird.io") + cd.SetUserGroups([]string{"grp-a", "grp-b", "grp-c"}) + cd.SetUserGroupNames([]string{"engineering", "EU, EMEA", "operations"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Equal(t, "engineering,operations", pr.Out.Header.Get(headerNetBirdGroups), + "group label with embedded comma must be dropped, remaining labels stamped") +} + +// TestStampNetBirdIdentity_RejectsControlCharsInEmail covers the +// header-injection defence: an email value containing CR/LF/control +// chars is omitted entirely (not partially stamped) so the upstream +// request stays well-formed and no header injection is possible. +func TestStampNetBirdIdentity_RejectsControlCharsInEmail(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + cd.SetUserEmail("alice@netbird.io\r\nX-Admin: yes") + cd.SetUserGroups([]string{"grp-a"}) + cd.SetUserGroupNames([]string{"engineering"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser), + "email with CR/LF must be dropped, not partially stamped") + assert.Equal(t, "engineering", pr.Out.Header.Get(headerNetBirdGroups), + "groups remain stampable even when email is invalid") +} + +// TestStampNetBirdIdentity_RejectsControlCharsInGroup covers the +// per-label defence: a group name with a control char is silently +// dropped, the rest are stamped. +func TestStampNetBirdIdentity_RejectsControlCharsInGroup(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + + cd := NewCapturedData("req-1") + cd.SetUserEmail("alice@netbird.io") + cd.SetUserGroups([]string{"grp-a", "grp-b"}) + cd.SetUserGroupNames([]string{"engineering\r\nsneaky", "operations"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Equal(t, "operations", pr.Out.Header.Get(headerNetBirdGroups), + "group label with control char must be dropped, valid ones kept") +} + +// TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid covers the +// edge case where every group label is rejected: the header must not be +// set at all (rather than set to an empty string). +func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + cd.SetUserEmail("alice@netbird.io") + cd.SetUserGroups([]string{"grp-a", "grp-b"}) + cd.SetUserGroupNames([]string{"with,comma", "with\nbreak"}) + + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + _, present := pr.Out.Header[http.CanonicalHeaderKey(headerNetBirdGroups)] + assert.False(t, present, + "X-NetBird-Groups must not be set when every group label is rejected") +} + +// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests +// that carry CapturedData with no identity fields populated (e.g. the +// auth middleware ran but the request didn't authenticate). Both +// headers must be cleared and neither stamped. +func TestStampNetBirdIdentity_CapturedDataPresentButEmpty(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + p := &ReverseProxy{forwardedProto: "auto"} + rewrite := p.rewriteFunc(target, "", false, PathRewriteDefault, nil, nil) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set(headerNetBirdUser, "spoofed@evil.io") + pr.In.Header.Set(headerNetBirdGroups, "spoofed-admin") + pr.Out.Header = pr.In.Header.Clone() + + cd := NewCapturedData("req-1") + pr.In = pr.In.WithContext(WithCapturedData(pr.In.Context(), cd)) + + rewrite(pr) + + assert.Empty(t, pr.Out.Header.Get(headerNetBirdUser), + "X-NetBird-User must be stripped when CapturedData has no email") + assert.Empty(t, pr.Out.Header.Get(headerNetBirdGroups), + "X-NetBird-Groups must be stripped when CapturedData has no groups") +} diff --git a/proxy/internal/proxy/servicemapping.go b/proxy/internal/proxy/servicemapping.go index fe470cf01..46b4d2e8d 100644 --- a/proxy/internal/proxy/servicemapping.go +++ b/proxy/internal/proxy/servicemapping.go @@ -28,6 +28,10 @@ type PathTarget struct { RequestTimeout time.Duration PathRewrite PathRewriteMode CustomHeaders map[string]string + // DirectUpstream selects the stdlib HTTP transport (host network stack) + // over the embedded NetBird WireGuard client when forwarding requests + // to this target. Default false โ†’ embedded client (existing behaviour). + DirectUpstream bool } // Mapping describes how a domain is routed by the HTTP reverse proxy. diff --git a/proxy/internal/restrict/restrict.go b/proxy/internal/restrict/restrict.go index f3e0fa695..67756b88b 100644 --- a/proxy/internal/restrict/restrict.go +++ b/proxy/internal/restrict/restrict.go @@ -191,6 +191,18 @@ func (f *Filter) IsObserveOnly(v Verdict) bool { return v.IsCrowdSec() && f.CrowdSecMode == CrowdSecObserve } +// CheckCIDR runs only the CIDR allow/block evaluation. Use this when +// country and CrowdSec checks don't apply โ€” e.g. requests arriving +// from the WireGuard overlay, whose source addresses live in the +// CGNAT range and have no meaningful geolocation or IP-reputation +// data. +func (f *Filter) CheckCIDR(addr netip.Addr) Verdict { + if f == nil { + return Allow + } + return f.checkCIDR(addr.Unmap()) +} + // Check evaluates whether addr is permitted. CIDR rules are evaluated // first because they are O(n) prefix comparisons. Country rules run // only when CIDR checks pass and require a geo lookup. CrowdSec checks diff --git a/proxy/internal/restrict/restrict_test.go b/proxy/internal/restrict/restrict_test.go index abaa1afdc..16aa1e139 100644 --- a/proxy/internal/restrict/restrict_test.go +++ b/proxy/internal/restrict/restrict_test.go @@ -514,6 +514,34 @@ func TestFilter_CrowdSec_Observe_NilChecker(t *testing.T) { assert.Equal(t, Allow, f.Check(netip.MustParseAddr("1.2.3.4"), nil)) } +func TestFilter_CheckCIDR_AllowsWithoutCountryOrCrowdSec(t *testing.T) { + cs := &mockCrowdSec{ready: true, decisions: map[string]*CrowdSecDecision{ + "100.64.5.6": {Type: DecisionBan}, + }} + f := ParseFilter(FilterConfig{ + AllowedCIDRs: []string{"100.64.0.0/10"}, + AllowedCountries: []string{"US"}, + CrowdSec: cs, + CrowdSecMode: CrowdSecEnforce, + }) + + // CheckCIDR skips country + CrowdSec evaluation: an address inside + // the allowed CIDR passes even when it would be denied by CrowdSec + // or by the country allowlist (CGNAT addresses have no geo data). + assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")), + "CheckCIDR must not run CrowdSec lookups on overlay traffic") + + // CIDR denials still fire. + assert.Equal(t, DenyCIDR, f.CheckCIDR(netip.MustParseAddr("198.51.100.1")), + "CheckCIDR must still reject addresses outside the allow list") +} + +func TestFilter_CheckCIDR_NilFilter(t *testing.T) { + var f *Filter + assert.Equal(t, Allow, f.CheckCIDR(netip.MustParseAddr("100.64.5.6")), + "CheckCIDR on a nil filter must allow") +} + func TestFilter_HasRestrictions_CrowdSec(t *testing.T) { cs := &mockCrowdSec{ready: true} f := ParseFilter(FilterConfig{CrowdSec: cs, CrowdSecMode: CrowdSecEnforce}) diff --git a/proxy/internal/roundtrip/multi.go b/proxy/internal/roundtrip/multi.go new file mode 100644 index 000000000..567249437 --- /dev/null +++ b/proxy/internal/roundtrip/multi.go @@ -0,0 +1,112 @@ +package roundtrip + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + "time" + + log "github.com/sirupsen/logrus" +) + +// MultiTransport dispatches each request to either the embedded NetBird +// http.RoundTripper or a stdlib http.Transport based on a per-request +// context flag set by the reverse-proxy rewrite step. When the flag is +// absent (the default for every existing target), requests follow the +// embedded NetBird path โ€” current behaviour, preserved. +// +// The stdlib branch is used when a target was configured with +// direct_upstream=true. It dials via the host's network stack, which is +// what private (`netbird proxy`) deployments and centralised proxies +// fronting host-reachable upstreams (public APIs, LAN services, +// localhost sidecars) want. +// +// An embedded roundtripper is required. To run direct-only (no WG +// branch at all), construct the MultiTransport via NewDirectOnly. +type MultiTransport struct { + embedded http.RoundTripper + direct *http.Transport + insecure *http.Transport +} + +// errNoEmbeddedTransport is returned when a request reaches the +// embedded branch on a MultiTransport that wasn't given one. Surfaces +// the misconfiguration to the caller instead of silently routing to +// the direct branch, which would bypass the WG tunnel. +var errNoEmbeddedTransport = errors.New("multitransport: embedded roundtripper not configured") + +// NewMultiTransport wires both branches. embedded is the existing NetBird +// roundtripper and must not be nil โ€” pass to NewDirectOnly for a +// MultiTransport that only ever uses the direct branch. The direct +// branches honour the same NB_PROXY_* tuning env vars as the embedded +// transport (see loadTransportConfig) plus a dial-timeout wrapper that +// respects types.WithDialTimeout. +func NewMultiTransport(embedded http.RoundTripper, logger *log.Logger) *MultiTransport { + if logger == nil { + logger = log.StandardLogger() + } + cfg := loadTransportConfig(logger) + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + direct := &http.Transport{ + DialContext: dialWithTimeout(dialer.DialContext), + ForceAttemptHTTP2: true, + MaxIdleConns: cfg.maxIdleConns, + MaxIdleConnsPerHost: cfg.maxIdleConnsPerHost, + MaxConnsPerHost: cfg.maxConnsPerHost, + IdleConnTimeout: cfg.idleConnTimeout, + TLSHandshakeTimeout: cfg.tlsHandshakeTimeout, + ExpectContinueTimeout: cfg.expectContinueTimeout, + ResponseHeaderTimeout: cfg.responseHeaderTimeout, + WriteBufferSize: cfg.writeBufferSize, + ReadBufferSize: cfg.readBufferSize, + DisableCompression: cfg.disableCompression, + } + insecure := direct.Clone() + insecure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // matches the embedded NetBird transport's per-target opt-in + + return &MultiTransport{ + embedded: embedded, + direct: direct, + insecure: insecure, + } +} + +// NewDirectOnly returns a MultiTransport with no embedded branch. +// Every request goes through the direct branch regardless of the +// per-request flag, so the embedded path can never be reached +// silently โ€” wiring code that needs WG must use NewMultiTransport. +func NewDirectOnly(logger *log.Logger) *MultiTransport { + return NewMultiTransport(noEmbeddedRoundTripper{}, logger) +} + +// noEmbeddedRoundTripper is the sentinel embedded transport for +// direct-only MultiTransports. RoundTrip is never called in practice +// because the direct branch matches every request, but if anything +// ever did reach this path it would fail loudly instead of falling +// back to direct. +type noEmbeddedRoundTripper struct{} + +func (noEmbeddedRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errNoEmbeddedTransport +} + +// RoundTrip dispatches by reading the direct-upstream flag from the request +// context. When set, the request is forwarded via the stdlib transport, +// honouring the existing per-request skip-TLS-verify flag. Otherwise it +// goes through the embedded NetBird roundtripper. +func (m *MultiTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if DirectUpstreamFromContext(req.Context()) { + if skipTLSVerifyFromContext(req.Context()) { + return m.insecure.RoundTrip(req) + } + return m.direct.RoundTrip(req) + } + if m.embedded == nil { + return nil, errNoEmbeddedTransport + } + return m.embedded.RoundTrip(req) +} diff --git a/proxy/internal/roundtrip/multi_test.go b/proxy/internal/roundtrip/multi_test.go new file mode 100644 index 000000000..5c6cf1c97 --- /dev/null +++ b/proxy/internal/roundtrip/multi_test.go @@ -0,0 +1,134 @@ +package roundtrip + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubRoundTripper records whether RoundTrip was called and returns a +// canned response so tests can assert the dispatch decision without +// running a real network. +type stubRoundTripper struct { + called bool + body string +} + +func (s *stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + s.called = true + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(s.body)), + Header: http.Header{}, + }, nil +} + +func TestMultiTransport_DispatchesByContextFlag(t *testing.T) { + embedded := &stubRoundTripper{body: "embedded"} + mt := NewMultiTransport(embedded, nil) + + t.Run("default routes to embedded", func(t *testing.T) { + embedded.called = false + req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil) + resp, err := mt.RoundTrip(req) + require.NoError(t, err, "embedded path must not error on stubbed transport") + require.NotNil(t, resp) + _ = resp.Body.Close() + assert.True(t, embedded.called, "request without WithDirectUpstream must hit the embedded transport") + }) + + t.Run("WithDirectUpstream skips embedded", func(t *testing.T) { + embedded.called = false + // Hit a server we control to verify the stdlib transport is used. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "direct") + })) + defer srv.Close() + + req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + resp, err := mt.RoundTrip(req) + require.NoError(t, err, "direct path must dial via stdlib transport") + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + require.NoError(t, err) + assert.Equal(t, "direct", string(body), "stdlib transport must reach the test server") + assert.False(t, embedded.called, "WithDirectUpstream must bypass the embedded transport") + }) +} + +// TestMultiTransport_AppliesEnvOverridesToDirect verifies that the +// NB_PROXY_* env vars consumed by loadTransportConfig flow into the +// direct branches (previously they only applied to the embedded +// roundtripper, so direct-upstream traffic ignored operator tuning). +func TestMultiTransport_AppliesEnvOverridesToDirect(t *testing.T) { + t.Setenv(EnvMaxIdleConns, "42") + t.Setenv(EnvIdleConnTimeout, "11s") + t.Setenv(EnvTLSHandshakeTimeout, "7s") + + mt := NewMultiTransport(&stubRoundTripper{body: "embedded"}, nil) + + assert.Equal(t, 42, mt.direct.MaxIdleConns, + "NB_PROXY_MAX_IDLE_CONNS must propagate to the direct transport") + assert.Equal(t, 11*time.Second, mt.direct.IdleConnTimeout, + "NB_PROXY_IDLE_CONN_TIMEOUT must propagate to the direct transport") + assert.Equal(t, 7*time.Second, mt.direct.TLSHandshakeTimeout, + "NB_PROXY_TLS_HANDSHAKE_TIMEOUT must propagate to the direct transport") + assert.Equal(t, 42, mt.insecure.MaxIdleConns, + "env tuning must also apply to the insecure-skip-verify direct transport") +} + +// TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested guards +// against the previous silent fallback: a MultiTransport constructed +// without an embedded transport must reject requests that don't +// explicitly opt into the direct branch, rather than routing them +// over the host stack and bypassing WireGuard. +func TestMultiTransport_NilEmbeddedErrorsWhenWGPathRequested(t *testing.T) { + mt := NewMultiTransport(nil, nil) + + req := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil) + resp, err := mt.RoundTrip(req) + if resp != nil { + _ = resp.Body.Close() + } + require.Error(t, err, "nil embedded must surface as an explicit error, not a silent direct dispatch") + assert.Nil(t, resp) + assert.ErrorIs(t, err, errNoEmbeddedTransport, + "the error must be the sentinel so callers can distinguish misconfiguration from network failures") +} + +// TestMultiTransport_DirectOnlyServesDirectBranch verifies NewDirectOnly +// constructs a MultiTransport whose direct branch handles requests with +// the direct-upstream flag set, and surfaces the explicit sentinel +// when the embedded path is reached. +func TestMultiTransport_DirectOnlyServesDirectBranch(t *testing.T) { + mt := NewDirectOnly(nil) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + defer srv.Close() + + req, err := http.NewRequestWithContext(WithDirectUpstream(context.Background()), http.MethodGet, srv.URL, nil) + require.NoError(t, err) + resp, err := mt.RoundTrip(req) + require.NoError(t, err, "direct-only must serve requests that opt into the direct branch") + _ = resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + wgReq := httptest.NewRequest(http.MethodGet, "http://example.invalid", nil) + resp, err = mt.RoundTrip(wgReq) + if resp != nil { + _ = resp.Body.Close() + } + require.Error(t, err, "direct-only must refuse requests that didn't opt into the direct branch") + assert.Nil(t, resp) + assert.ErrorIs(t, err, errNoEmbeddedTransport) +} diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 28eba7810..133e86f05 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "net/netip" "sync" "time" @@ -76,11 +77,11 @@ type clientEntry struct { services map[ServiceKey]serviceInfo createdAt time.Time started bool - // ready is closed once the client has been fully initialized. - // Callers that find a pending entry wait on this channel before - // accessing the client. A nil initErr means success. - ready chan struct{} - initErr error + // inbound is opaque per-account state owned by the NetBird parent's + // ReadyHandler. The roundtrip package never inspects this value; it + // only stores it so RemovePeer / StopAll can hand it back to the + // matching StopHandler. Nil when no inbound integration is active. + inbound any // Per-backend in-flight limiting keyed by target host:port. // TODO: clean up stale entries when backend targets change. inflightMu sync.Mutex @@ -88,6 +89,19 @@ type clientEntry struct { maxInflight int } +// IdentityForIP resolves a tunnel IP to the peer identity locally known by +// this account's embedded client. Returns (pubKey, fqdn) on success. +// ok=false means the IP is not in the account's roster โ€” callers can use +// that as a fast deny without round-tripping management. The returned +// strings carry only what the embedded peerstore exposes; user identity +// (UserID / Email / Groups) still flows through ValidateTunnelPeer. +func (e *clientEntry) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) { + if e == nil || e.client == nil || !ip.IsValid() { + return "", "", false + } + return e.client.IdentityForIP(ip) +} + // acquireInflight attempts to acquire an in-flight slot for the given backend. // It returns a release function that must always be called, and true on success. func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) { @@ -117,6 +131,12 @@ type ClientConfig struct { MgmtAddr string WGPort uint16 PreSharedKey string + // BlockInbound mirrors embed.Options.BlockInbound. Set to true on the + // standalone proxy where the embedded client never accepts inbound; + // set to false on the private/embedded proxy so the engine creates + // the ACL manager and applies management's per-policy firewall rules + // (which is what gates per-account inbound listeners on the netstack). + BlockInbound bool } type statusNotifier interface { @@ -142,6 +162,14 @@ type NetBird struct { clients map[types.AccountID]*clientEntry initLogOnce sync.Once statusNotifier statusNotifier + // readyHandler runs after the embedded client for an account reports + // Ready. The opaque return value is stored on clientEntry and handed + // back to stopHandler when the entry is torn down. Nil disables the + // hook entirely (default for the standalone proxy). + readyHandler func(ctx context.Context, accountID types.AccountID, client *embed.Client) any + // stopHandler runs when an account's last service is removed (or the + // transport is shutting down). Receives whatever readyHandler returned. + stopHandler func(accountID types.AccountID, state any) // OnAddPeer, when set, is called after AddPeer completes for a new account // (i.e. when a new client was actually created, not when an existing one @@ -167,9 +195,6 @@ type skipTLSVerifyContextKey struct{} // AddPeer registers a service for an account. If the account doesn't have a client yet, // one is created by authenticating with the management server using the provided token. // Multiple services can share the same client. -// -// Client creation (WG keygen, gRPC, embed.New) runs without holding clientsMux -// so that concurrent AddPeer calls for different accounts execute in parallel. func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error { si := serviceInfo{serviceID: serviceID} @@ -177,23 +202,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se entry, exists := n.clients[accountID] if exists { - ready := entry.ready entry.services[key] = si started := entry.started n.clientsMux.Unlock() - // If the entry is still being initialized by another goroutine, wait. - if ready != nil { - select { - case <-ready: - case <-ctx.Done(): - return ctx.Err() - } - if entry.initErr != nil { - return fmt.Errorf("peer initialization failed: %w", entry.initErr) - } - } - n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, @@ -210,43 +222,19 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se return nil } - // Insert a placeholder so other goroutines calling AddPeer for the same - // account will wait on the ready channel instead of starting a second - // client creation. - entry = &clientEntry{ - services: map[ServiceKey]serviceInfo{key: si}, - ready: make(chan struct{}), - } - n.clients[accountID] = entry - n.clientsMux.Unlock() - createStart := time.Now() - created, err := n.createClientEntry(ctx, accountID, key, authToken, si) + entry, err := n.createClientEntry(ctx, accountID, key, authToken, si) if n.OnAddPeer != nil { n.OnAddPeer(time.Since(createStart), err) } if err != nil { - entry.initErr = err - close(entry.ready) - - n.clientsMux.Lock() - delete(n.clients, accountID) n.clientsMux.Unlock() return err } - // Transfer any services that were registered by concurrent AddPeer calls - // while we were creating the client. - n.clientsMux.Lock() - for k, v := range entry.services { - created.services[k] = v - } - created.ready = nil - n.clients[accountID] = created + n.clients[accountID] = entry n.clientsMux.Unlock() - close(entry.ready) - n.logger.WithFields(log.Fields{ "account_id": accountID, "service_key": key, @@ -254,13 +242,13 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se // Attempt to start the client in the background; if this fails we will // retry on the first request via RoundTrip. - go n.runClientStartup(ctx, accountID, created.client) + go n.runClientStartup(ctx, accountID, entry.client) return nil } // createClientEntry generates a WireGuard keypair, authenticates with management, -// and creates an embedded NetBird client. +// and creates an embedded NetBird client. Must be called with clientsMux held. func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) { serviceID := si.serviceID n.logger.WithFields(log.Fields{ @@ -318,9 +306,15 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), - BlockInbound: true, - WireguardPort: &wgPort, - PreSharedKey: n.clientCfg.PreSharedKey, + BlockInbound: n.clientCfg.BlockInbound, + // The embedded proxy peer must never be a stepping stone into + // the proxy host's LAN: it only exists to reach NetBird mesh + // targets or, when direct_upstream is set, the host network + // stack via the MultiTransport's direct branch (which bypasses + // the engine routing entirely). + BlockLANAccess: true, + WireguardPort: &wgPort, + PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) @@ -385,8 +379,25 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI toNotify = append(toNotify, serviceNotification{key: key, serviceID: info.serviceID}) } } + readyHandler := n.readyHandler n.clientsMux.Unlock() + if readyHandler != nil { + state := readyHandler(ctx, accountID, client) + n.clientsMux.Lock() + if e, ok := n.clients[accountID]; ok { + e.inbound = state + } else if state != nil && n.stopHandler != nil { + // Account was removed while readyHandler ran; tear down the + // resources it just brought up. + stop := n.stopHandler + n.clientsMux.Unlock() + stop(accountID, state) + n.clientsMux.Lock() + } + n.clientsMux.Unlock() + } + if n.statusNotifier == nil { return } @@ -432,11 +443,15 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key stopClient := len(entry.services) == 0 var client *embed.Client var transport, insecureTransport *http.Transport + var inbound any + var stopHandler func(types.AccountID, any) if stopClient { n.logger.WithField("account_id", accountID).Info("stopping client, no more services") client = entry.client transport = entry.transport insecureTransport = entry.insecureTransport + inbound = entry.inbound + stopHandler = n.stopHandler delete(n.clients, accountID) } else { n.logger.WithFields(log.Fields{ @@ -450,6 +465,9 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key n.notifyDisconnect(ctx, accountID, key, si.serviceID) if stopClient { + if inbound != nil && stopHandler != nil { + stopHandler(accountID, inbound) + } transport.CloseIdleConnections() insecureTransport.CloseIdleConnections() if err := client.Stop(ctx); err != nil { @@ -536,8 +554,12 @@ func (n *NetBird) StopAll(ctx context.Context) error { n.clientsMux.Lock() defer n.clientsMux.Unlock() + stopHandler := n.stopHandler var merr *multierror.Error for accountID, entry := range n.clients { + if entry.inbound != nil && stopHandler != nil { + stopHandler(accountID, entry.inbound) + } entry.transport.CloseIdleConnections() entry.insecureTransport.CloseIdleConnections() if err := entry.client.Stop(ctx); err != nil { @@ -590,6 +612,19 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) { return entry.client, true } +// IdentityForIP resolves a tunnel IP to a peer identity local to the given +// account. Delegates to clientEntry.IdentityForIP. Returns ok=false when +// the account has no client or the IP is not in its peerstore. +func (n *NetBird) IdentityForIP(accountID types.AccountID, ip netip.Addr) (pubKey, fqdn string, ok bool) { + n.clientsMux.RLock() + entry, exists := n.clients[accountID] + n.clientsMux.RUnlock() + if !exists { + return "", "", false + } + return entry.IdentityForIP(ip) +} + // ListClientsForDebug returns information about all clients for debug purposes. func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { n.clientsMux.RLock() @@ -645,6 +680,18 @@ func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.L } } +// SetClientLifecycle registers callbacks that run when an embedded +// client becomes ready and when its entry is torn down. The opaque value +// returned by ready is stored on the entry and handed back to stop on +// cleanup. Must be called before AddPeer. A nil pair leaves the +// outbound-only behaviour intact. +func (n *NetBird) SetClientLifecycle(ready func(ctx context.Context, accountID types.AccountID, client *embed.Client) any, stop func(accountID types.AccountID, state any)) { + n.clientsMux.Lock() + defer n.clientsMux.Unlock() + n.readyHandler = ready + n.stopHandler = stop +} + // dialWithTimeout wraps a DialContext function so that any dial timeout // stored in the context (via types.WithDialTimeout) is applied only to // the connection establishment phase, not the full request lifetime. @@ -687,3 +734,22 @@ func skipTLSVerifyFromContext(ctx context.Context) bool { v, _ := ctx.Value(skipTLSVerifyContextKey{}).(bool) return v } + +// directUpstreamContextKey signals that the request should bypass the embedded +// NetBird WireGuard client and dial via the host's network stack instead. +// Set by the reverse-proxy rewrite step when the matched target carries +// PathTarget.DirectUpstream; consumed by MultiTransport. +type directUpstreamContextKey struct{} + +// WithDirectUpstream marks the context so MultiTransport routes the request +// through its stdlib transport instead of the embedded NetBird roundtripper. +func WithDirectUpstream(ctx context.Context) context.Context { + return context.WithValue(ctx, directUpstreamContextKey{}, true) +} + +// DirectUpstreamFromContext reports whether the context has been marked to +// bypass the embedded NetBird client. +func DirectUpstreamFromContext(ctx context.Context) bool { + v, _ := ctx.Value(directUpstreamContextKey{}).(bool) + return v +} diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 5444f6c11..3f3e4138a 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -3,6 +3,7 @@ package roundtrip import ( "context" "net/http" + "net/netip" "sync" "testing" @@ -305,6 +306,36 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { assert.True(t, calls[0].connected) } +// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the +// public lookup short-circuits when no client has been registered for +// the queried account. The auth middleware uses ok=false as a fast deny. +func TestNetBird_IdentityForIP_UnknownAccountReturnsFalse(t *testing.T) { + nb := mockNetBird() + _, _, ok := nb.IdentityForIP("acct-missing", netip.MustParseAddr("100.64.0.10")) + assert.False(t, ok, "unknown account must yield ok=false") +} + +// TestClientEntry_IdentityForIP_NilClientGuard ensures the receiver +// methods stay safe when called on partially-initialized state, which +// can happen briefly during AddPeer setup or test fixtures. +func TestClientEntry_IdentityForIP_NilClientGuard(t *testing.T) { + var e *clientEntry + _, _, ok := e.IdentityForIP(netip.MustParseAddr("100.64.0.10")) + assert.False(t, ok, "nil clientEntry must yield ok=false") + + e = &clientEntry{} + _, _, ok = e.IdentityForIP(netip.MustParseAddr("100.64.0.10")) + assert.False(t, ok, "clientEntry with nil embed.Client must yield ok=false") +} + +// TestClientEntry_IdentityForIP_InvalidIPReturnsFalse covers the input +// guard so callers don't have to repeat the check. +func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) { + e := &clientEntry{} + _, _, ok := e.IdentityForIP(netip.Addr{}) + assert.False(t, ok, "invalid IP must yield ok=false") +} + func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { notifier := &mockStatusNotifier{} nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ diff --git a/proxy/internal/tcp/bench_test.go b/proxy/internal/tcp/bench_test.go index 049f8395d..6c0b1eea7 100644 --- a/proxy/internal/tcp/bench_test.go +++ b/proxy/internal/tcp/bench_test.go @@ -36,7 +36,7 @@ func BenchmarkPeekClientHello_TLS(b *testing.B) { for b.Loop() { r := bytes.NewReader(hello) conn := &readerConn{Reader: r} - sni, wrapped, err := PeekClientHello(conn) + sni, wrapped, _, err := PeekClientHello(conn) if err != nil { b.Fatal(err) } @@ -59,7 +59,7 @@ func BenchmarkPeekClientHello_NonTLS(b *testing.B) { for b.Loop() { r := bytes.NewReader(httpReq) conn := &readerConn{Reader: r} - _, wrapped, err := PeekClientHello(conn) + _, wrapped, _, err := PeekClientHello(conn) if err != nil { b.Fatal(err) } diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 05beb658b..15c5022b0 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -100,28 +100,50 @@ type Router struct { // httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter. httpCh chan net.Conn httpListener *chanListener - mu sync.RWMutex - routes map[SNIHost][]Route - fallback *Route - draining bool - dialResolve DialResolver - activeConns sync.WaitGroup - activeRelays sync.WaitGroup - relaySem chan struct{} - drainDone chan struct{} - observer RelayObserver - accessLog l4Logger - geo restrict.GeoResolver + // httpPlainCh feeds non-TLS HTTP connections to a parallel http.Server. + // Set only when NewRouter is called with WithPlainHTTP option (used by + // per-account inbound listeners that accept both :80 and :443 traffic). + // Nil for the host SNI router and for port routers. + httpPlainCh chan net.Conn + httpPlainListener *chanListener + mu sync.RWMutex + routes map[SNIHost][]Route + fallback *Route + draining bool + dialResolve DialResolver + activeConns sync.WaitGroup + activeRelays sync.WaitGroup + relaySem chan struct{} + drainDone chan struct{} + observer RelayObserver + accessLog l4Logger + geo restrict.GeoResolver // svcCtxs tracks a context per service ID. All relay goroutines for a // service derive from its context; canceling it kills them immediately. svcCtxs map[types.ServiceID]context.Context svcCancels map[types.ServiceID]context.CancelFunc } +// RouterOption customises Router construction. +type RouterOption func(*Router) + +// WithPlainHTTP enables a parallel plain-HTTP channel on the router. When +// set, connections whose first byte is not a TLS handshake are forwarded +// to the plain channel returned by HTTPListenerPlain instead of the TLS +// channel. Used by per-account inbound listeners that share both :80 and +// :443 traffic on the same router. +func WithPlainHTTP(addr net.Addr) RouterOption { + return func(r *Router) { + ch := make(chan net.Conn, httpChannelBuffer) + r.httpPlainCh = ch + r.httpPlainListener = newChanListener(ch, addr) + } +} + // NewRouter creates a new SNI-based connection router. -func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router { +func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr, opts ...RouterOption) *Router { httpCh := make(chan net.Conn, httpChannelBuffer) - return &Router{ + r := &Router{ logger: logger, httpCh: httpCh, httpListener: newChanListener(httpCh, addr), @@ -131,6 +153,10 @@ func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Rou svcCtxs: make(map[types.ServiceID]context.Context), svcCancels: make(map[types.ServiceID]context.CancelFunc), } + for _, opt := range opts { + opt(r) + } + return r } // NewPortRouter creates a Router for a dedicated port without an HTTP @@ -153,6 +179,16 @@ func (r *Router) HTTPListener() net.Listener { return r.httpListener } +// HTTPListenerPlain returns a net.Listener yielding non-TLS connections +// for use with a parallel plain http.Server. Returns nil when the router +// was not constructed with WithPlainHTTP. +func (r *Router) HTTPListenerPlain() net.Listener { + if r.httpPlainListener == nil { + return nil + } + return r.httpPlainListener +} + // AddRoute registers an SNI route. Multiple routes for the same host are // stored and resolved by priority at lookup time (HTTP > TCP). // Empty host is ignored to prevent conflicts with ECH/ESNI fallback. @@ -254,6 +290,9 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error { if r.httpListener != nil { r.httpListener.Close() } + if r.httpPlainListener != nil { + r.httpPlainListener.Close() + } case <-done: } }() @@ -270,6 +309,7 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error { r.logger.Debugf("SNI router accept: %v", err) continue } + r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr()) r.activeConns.Add(1) go func() { defer r.activeConns.Done() @@ -278,13 +318,24 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error { } } +// HandleConn lets external accept loops feed a connection through the +// router's peek-and-dispatch logic. Use this when the same router serves +// a secondary listener (for example, a per-account inbound :80 socket +// alongside its :443 socket). +func (r *Router) HandleConn(ctx context.Context, conn net.Conn) { + r.activeConns.Add(1) + defer r.activeConns.Done() + r.handleConn(ctx, conn) +} + // handleConn peeks at the TLS ClientHello and routes the connection. func (r *Router) handleConn(ctx context.Context, conn net.Conn) { // Fast path: when no SNI routes and no HTTP channel exist (pure TCP // fallback port), skip the TLS peek entirely to avoid read errors on // non-TLS connections and reduce latency. if r.isFallbackOnly() { - r.handleUnmatched(ctx, conn) + r.logger.Debugf("SNI router fallback-only mode for conn from %s; skipping ClientHello peek", conn.RemoteAddr()) + r.handleUnmatched(ctx, conn, false) return } @@ -294,11 +345,11 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) { return } - sni, wrapped, err := PeekClientHello(conn) + sni, wrapped, isTLS, err := PeekClientHello(conn) if err != nil { - r.logger.Debugf("SNI peek: %v", err) + r.logger.Debugf("SNI peek failed for conn from %s: %v", conn.RemoteAddr(), err) if wrapped != nil { - r.handleUnmatched(ctx, wrapped) + r.handleUnmatched(ctx, wrapped, isTLS) } else { _ = conn.Close() } @@ -313,13 +364,20 @@ func (r *Router) handleConn(ctx context.Context, conn net.Conn) { host := SNIHost(strings.ToLower(sni)) route, ok := r.lookupRoute(host) + r.logger.WithFields(log.Fields{ + "remote": wrapped.RemoteAddr().String(), + "sni": string(host), + "match": ok, + "tls": isTLS, + }).Debug("SNI route lookup") if !ok { - r.handleUnmatched(ctx, wrapped) + r.handleUnmatched(ctx, wrapped, isTLS) return } if route.Type == RouteHTTP { - r.sendToHTTP(wrapped) + r.logger.Debugf("SNI %q routed to HTTP handler (service_id=%s)", host, route.ServiceID) + r.sendToHTTP(wrapped, isTLS) return } @@ -344,15 +402,17 @@ func (r *Router) isFallbackOnly() bool { } // handleUnmatched routes a connection that didn't match any SNI route. -// This includes ECH/ESNI connections where the cleartext SNI is empty. +// This includes ECH/ESNI connections where the cleartext SNI is empty, +// and plain (non-TLS) HTTP connections when isTLS is false. // It tries the fallback relay first, then the HTTP channel, and closes // the connection if neither is available. -func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { +func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn, isTLS bool) { r.mu.RLock() fb := r.fallback r.mu.RUnlock() if fb != nil { + r.logger.Debugf("unmatched conn from %s relayed to TCP fallback (service_id=%s, target=%s)", conn.RemoteAddr(), fb.ServiceID, fb.Target) if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil { if !errors.Is(err, errAccessRestricted) { r.logger.WithFields(log.Fields{ @@ -364,7 +424,8 @@ func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) { } return } - r.sendToHTTP(conn) + r.logger.Debugf("unmatched conn from %s sent to HTTP channel (no TCP fallback configured)", conn.RemoteAddr()) + r.sendToHTTP(conn, isTLS) } // lookupRoute returns the highest-priority route for the given SNI host. @@ -386,10 +447,20 @@ func (r *Router) lookupRoute(host SNIHost) (Route, bool) { } // sendToHTTP feeds the connection to the HTTP handler via the channel. -// If no HTTP channel is configured (port router), the router is -// draining, or the channel is full, the connection is closed. -func (r *Router) sendToHTTP(conn net.Conn) { - if r.httpCh == nil { +// When isTLS is false and a plain channel is configured the connection +// is forwarded to the plain channel; otherwise it lands on the TLS +// channel. If no usable channel exists, the router is draining, or the +// channel is full, the connection is closed. +func (r *Router) sendToHTTP(conn net.Conn, isTLS bool) { + ch := r.httpCh + chanName := "HTTP" + if !isTLS && r.httpPlainCh != nil { + ch = r.httpPlainCh + chanName = "HTTP-plain" + } + + if ch == nil { + r.logger.Debugf("%s channel nil; dropping conn from %s", chanName, conn.RemoteAddr()) _ = conn.Close() return } @@ -399,14 +470,15 @@ func (r *Router) sendToHTTP(conn net.Conn) { r.mu.RUnlock() if draining { + r.logger.Debugf("router draining; dropping conn from %s", conn.RemoteAddr()) _ = conn.Close() return } select { - case r.httpCh <- conn: + case ch <- conn: default: - r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr()) + r.logger.Warnf("%s channel full, dropping connection from %s", chanName, conn.RemoteAddr()) _ = conn.Close() } } diff --git a/proxy/internal/tcp/router_test.go b/proxy/internal/tcp/router_test.go index 93b6560f4..2f96d142c 100644 --- a/proxy/internal/tcp/router_test.go +++ b/proxy/internal/tcp/router_test.go @@ -1739,3 +1739,97 @@ func TestCheckRestrictions_IPv4MappedIPv6(t *testing.T) { connOutside := &fakeConn{remote: fakeAddr("[::ffff:192.168.1.1]:5678")} assert.NotEqual(t, restrict.Allow, router.checkRestrictions(connOutside, route), "::ffff:192.168.1.1 not in v4 CIDR") } + +// TestRouter_PlainHTTP_RoutesToPlainChannel verifies that a plain (non-TLS) +// connection lands on the plain HTTP channel when the router was built +// with WithPlainHTTP, leaving the TLS channel untouched. +func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr, WithPlainHTTP(addr)) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "test listener bind must succeed") + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = router.Serve(ctx, ln) + }() + + // Plain HTTP request (no TLS handshake byte). + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) + }() + + plainListener := router.HTTPListenerPlain() + require.NotNil(t, plainListener, "plain listener must be exposed when WithPlainHTTP is set") + + acceptDone := make(chan net.Conn, 1) + go func() { + conn, err := plainListener.Accept() + if err == nil { + acceptDone <- conn + } + }() + + select { + case conn := <-acceptDone: + require.NotNil(t, conn) + _ = conn.Close() + case <-router.HTTPListener().(*chanListener).ch: + t.Fatal("plain HTTP request leaked into TLS channel") + case <-time.After(3 * time.Second): + t.Fatal("plain HTTP connection never reached plain channel") + } +} + +// TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled verifies that the +// presence of a plain channel does not divert TLS traffic โ€” TLS still +// goes to the TLS channel as before. +func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) { + logger := log.StandardLogger() + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} + + router := NewRouter(logger, nil, addr, WithPlainHTTP(addr)) + router.AddRoute("example.com", Route{Type: RouteHTTP}) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "test listener bind must succeed") + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = router.Serve(ctx, ln) }() + + // Send a TLS ClientHello. + go func() { + conn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + return + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: "example.com", + InsecureSkipVerify: true, //nolint:gosec + }) + _ = tlsConn.Handshake() + _ = tlsConn.Close() + }() + + select { + case conn := <-router.httpCh: + require.NotNil(t, conn, "TLS conn should land on the TLS channel") + _ = conn.Close() + case <-time.After(5 * time.Second): + t.Fatal("TLS conn never reached the TLS channel") + } +} diff --git a/proxy/internal/tcp/snipeek.go b/proxy/internal/tcp/snipeek.go index 25ab8e5ef..dc3c96498 100644 --- a/proxy/internal/tcp/snipeek.go +++ b/proxy/internal/tcp/snipeek.go @@ -30,26 +30,30 @@ const ( // bytes transparently. If the data is not a valid TLS ClientHello or // contains no SNI extension, sni is empty and err is nil. // +// isTLS reports whether the first byte indicated a TLS handshake record. +// Callers can use this to distinguish plain (non-TLS) traffic from a TLS +// stream that simply lacked an SNI extension or used ECH. +// // ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the // real server name is encrypted inside the encrypted_client_hello // extension. This parser only reads the cleartext server_name extension // (type 0x0000), so ECH connections return sni="" and are routed through // the fallback path (or HTTP channel), which is the correct behavior // for a transparent proxy that does not terminate TLS. -func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { +func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, isTLS bool, err error) { // Read the 5-byte TLS record header into a small stack-friendly buffer. var header [tlsRecordHeaderLen]byte if _, err := io.ReadFull(conn, header[:]); err != nil { - return "", nil, fmt.Errorf("read TLS record header: %w", err) + return "", nil, false, fmt.Errorf("read TLS record header: %w", err) } if header[0] != contentTypeHandshake { - return "", newPeekedConn(conn, header[:]), nil + return "", newPeekedConn(conn, header[:]), false, nil } recordLen := int(binary.BigEndian.Uint16(header[3:5])) if recordLen == 0 || recordLen > maxClientHelloLen { - return "", newPeekedConn(conn, header[:]), nil + return "", newPeekedConn(conn, header[:]), true, nil } // Single allocation for header + payload. The peekedConn takes @@ -59,11 +63,11 @@ func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) { n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:]) if err != nil { - return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err) + return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), true, fmt.Errorf("read TLS handshake payload: %w", err) } sni = extractSNI(buf[tlsRecordHeaderLen:]) - return sni, newPeekedConn(conn, buf), nil + return sni, newPeekedConn(conn, buf), true, nil } // extractSNI parses a TLS handshake payload to find the SNI extension. diff --git a/proxy/internal/tcp/snipeek_test.go b/proxy/internal/tcp/snipeek_test.go index 9afe6261d..85dee15c1 100644 --- a/proxy/internal/tcp/snipeek_test.go +++ b/proxy/internal/tcp/snipeek_test.go @@ -29,10 +29,11 @@ func TestPeekClientHello_ValidSNI(t *testing.T) { _ = tlsConn.Handshake() }() - sni, wrapped, err := PeekClientHello(serverConn) + sni, wrapped, isTLS, err := PeekClientHello(serverConn) require.NoError(t, err) assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello") assert.NotNil(t, wrapped, "wrapped connection should not be nil") + assert.True(t, isTLS, "TLS ClientHello should be flagged as TLS") // Verify the wrapped connection replays the peeked bytes. // Read the first 5 bytes (TLS record header) to confirm replay. @@ -83,10 +84,11 @@ func TestPeekClientHello_MultipleSNIs(t *testing.T) { _ = tlsConn.Handshake() }() - sni, wrapped, err := PeekClientHello(serverConn) + sni, wrapped, isTLS, err := PeekClientHello(serverConn) require.NoError(t, err) assert.Equal(t, tt.expectedSNI, sni) assert.NotNil(t, wrapped) + assert.True(t, isTLS, "TLS handshake should be flagged as TLS") }) } } @@ -102,10 +104,11 @@ func TestPeekClientHello_NonTLSData(t *testing.T) { _, _ = clientConn.Write(httpData) }() - sni, wrapped, err := PeekClientHello(serverConn) + sni, wrapped, isTLS, err := PeekClientHello(serverConn) require.NoError(t, err) assert.Empty(t, sni, "should return empty SNI for non-TLS data") assert.NotNil(t, wrapped) + assert.False(t, isTLS, "plain HTTP data should not be flagged as TLS") // Verify the wrapped connection still provides the original data. buf := make([]byte, len(httpData)) @@ -124,7 +127,7 @@ func TestPeekClientHello_TruncatedHeader(t *testing.T) { clientConn.Close() }() - _, _, err := PeekClientHello(serverConn) + _, _, _, err := PeekClientHello(serverConn) assert.Error(t, err, "should error on truncated header") } @@ -140,7 +143,7 @@ func TestPeekClientHello_TruncatedPayload(t *testing.T) { clientConn.Close() }() - _, _, err := PeekClientHello(serverConn) + _, _, _, err := PeekClientHello(serverConn) assert.Error(t, err, "should error on truncated payload") } @@ -154,10 +157,11 @@ func TestPeekClientHello_ZeroLengthRecord(t *testing.T) { _, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00}) }() - sni, wrapped, err := PeekClientHello(serverConn) + sni, wrapped, isTLS, err := PeekClientHello(serverConn) require.NoError(t, err) assert.Empty(t, sni) assert.NotNil(t, wrapped) + assert.True(t, isTLS, "zero-length record should still be a TLS handshake byte") } func TestExtractSNI_InvalidPayload(t *testing.T) { diff --git a/proxy/internal/types/types.go b/proxy/internal/types/types.go index bf3731803..ffdbf2301 100644 --- a/proxy/internal/types/types.go +++ b/proxy/internal/types/types.go @@ -54,3 +54,23 @@ func DialTimeoutFromContext(ctx context.Context) (time.Duration, bool) { d, ok := ctx.Value(dialTimeoutKey{}).(time.Duration) return d, ok && d > 0 } + +// overlayOriginKey is the context key set by per-account inbound +// listeners to mark a request as originating from the WireGuard +// overlay rather than the public-facing host listener. +type overlayOriginKey struct{} + +// WithOverlayOrigin marks the context as originating from the +// embedded NetBird overlay (tunnel-side inbound listener). +func WithOverlayOrigin(ctx context.Context) context.Context { + return context.WithValue(ctx, overlayOriginKey{}, true) +} + +// IsOverlayOrigin reports whether the request reached the proxy via +// the overlay listener. Middlewares that only make sense for WAN +// traffic (geolocation, CrowdSec IP reputation) should short-circuit +// when this is true. +func IsOverlayOrigin(ctx context.Context) bool { + v, _ := ctx.Value(overlayOriginKey{}).(bool) + return v +} diff --git a/proxy/lifecycle.go b/proxy/lifecycle.go new file mode 100644 index 000000000..6cb420722 --- /dev/null +++ b/proxy/lifecycle.go @@ -0,0 +1,160 @@ +package proxy + +import ( + "net/netip" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/acme" +) + +// Config bundles every knob the proxy reads at construction time. It mirrors +// the public fields on Server so library callers don't have to learn the +// internal struct layout. Zero values mean "feature off" or "fall back to the +// internal default" depending on the field โ€” see the per-field doc. +// +// The standalone binary continues to populate Server fields directly, so +// adding fields here must not change the zero-value behaviour of Server. +type Config struct { + // ListenAddr is the TCP address the main listener binds. Required. + ListenAddr string + // ID identifies this proxy instance to management. Empty value lets + // New generate a timestamped default. + ID string + // Logger is the logrus logger used everywhere. Empty value falls back + // to log.StandardLogger(). + Logger *log.Logger + // Version is the build version string reported to management. Empty + // becomes "dev". + Version string + // ProxyURL is the public address operators use to reach this proxy. + ProxyURL string + // ManagementAddress is the gRPC URL of the management server. + ManagementAddress string + // ProxyToken authenticates this proxy with the management server. + ProxyToken string + + // CertificateDirectory is the directory holding TLS certificate + // material (static or ACME-provisioned). + CertificateDirectory string + // CertificateFile is the certificate filename within + // CertificateDirectory. + CertificateFile string + // CertificateKeyFile is the private key filename within + // CertificateDirectory. + CertificateKeyFile string + // GenerateACMECertificates toggles ACME certificate provisioning. + GenerateACMECertificates bool + // ACMEChallengeAddress is the listen address for HTTP-01 challenges. + ACMEChallengeAddress string + // ACMEDirectory is the ACME directory URL (Let's Encrypt by default). + ACMEDirectory string + // ACMEEABKID is the External Account Binding Key ID for CAs that + // require EAB (e.g. ZeroSSL). + ACMEEABKID string + // ACMEEABHMACKey is the External Account Binding HMAC key for CAs + // that require EAB. + ACMEEABHMACKey string + // ACMEChallengeType is the ACME challenge type ("tls-alpn-01" or + // "http-01"). Empty defaults to "tls-alpn-01". + ACMEChallengeType string + // CertLockMethod controls how ACME certificate locks are coordinated + // across replicas. + CertLockMethod acme.CertLockMethod + // WildcardCertDir is an optional directory containing static wildcard + // certificates that override ACME for matching domains. + WildcardCertDir string + + // DebugEndpointEnabled toggles the debug HTTP endpoint. + DebugEndpointEnabled bool + // DebugEndpointAddress is the bind address for the debug endpoint. + DebugEndpointAddress string + // HealthAddr is the bind address for the health probe and metrics + // surface. Empty disables the health probe entirely (library callers + // can attach their own). + HealthAddr string + + // ForwardedProto overrides the X-Forwarded-Proto value sent to + // backends. Valid values: "auto", "http", "https". + ForwardedProto string + // TrustedProxies is a list of IP prefixes for trusted upstream + // proxies that may set forwarding headers. + TrustedProxies []netip.Prefix + // WireguardPort is the UDP port for the embedded NetBird tunnel. + // Zero asks the OS for a random port. + WireguardPort uint16 + // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. + ProxyProtocol bool + // PreSharedKey is the WireGuard pre-shared key used between the + // proxy's embedded clients and peers. + PreSharedKey string + + // SupportsCustomPorts indicates whether the proxy can bind arbitrary + // ports for TCP/UDP/TLS services. + SupportsCustomPorts bool + // RequireSubdomain forces accounts to use a subdomain in front of + // the proxy's cluster domain. + RequireSubdomain bool + // Private flags this proxy as embedded in a netbird client and + // serving exclusively over the WireGuard tunnel. Also enables + // per-account inbound listeners on each embedded client's netstack. + Private bool + + // MaxDialTimeout caps the per-service backend dial timeout. + MaxDialTimeout time.Duration + // MaxSessionIdleTimeout caps the per-service session idle timeout. + MaxSessionIdleTimeout time.Duration + + // GeoDataDir is the directory containing GeoLite2 MMDB files. + GeoDataDir string + // CrowdSecAPIURL is the CrowdSec LAPI URL. Empty disables CrowdSec. + CrowdSecAPIURL string + // CrowdSecAPIKey is the CrowdSec bouncer API key. Empty disables + // CrowdSec. + CrowdSecAPIKey string +} + +// New builds a Server from cfg without performing any I/O. No goroutines +// are spawned, no network connections are dialed, and no listeners are +// bound โ€” call Start to bring the proxy up. Returning a fully-formed +// Server keeps the standalone code path (which still constructs Server +// directly) byte-for-byte equivalent. +func New(cfg Config) *Server { + return &Server{ + ListenAddr: cfg.ListenAddr, + ID: cfg.ID, + Logger: cfg.Logger, + Version: cfg.Version, + ProxyURL: cfg.ProxyURL, + ManagementAddress: cfg.ManagementAddress, + ProxyToken: cfg.ProxyToken, + CertificateDirectory: cfg.CertificateDirectory, + CertificateFile: cfg.CertificateFile, + CertificateKeyFile: cfg.CertificateKeyFile, + GenerateACMECertificates: cfg.GenerateACMECertificates, + ACMEChallengeAddress: cfg.ACMEChallengeAddress, + ACMEDirectory: cfg.ACMEDirectory, + ACMEEABKID: cfg.ACMEEABKID, + ACMEEABHMACKey: cfg.ACMEEABHMACKey, + ACMEChallengeType: cfg.ACMEChallengeType, + CertLockMethod: cfg.CertLockMethod, + WildcardCertDir: cfg.WildcardCertDir, + DebugEndpointEnabled: cfg.DebugEndpointEnabled, + DebugEndpointAddress: cfg.DebugEndpointAddress, + HealthAddress: cfg.HealthAddr, + ForwardedProto: cfg.ForwardedProto, + TrustedProxies: cfg.TrustedProxies, + WireguardPort: cfg.WireguardPort, + ProxyProtocol: cfg.ProxyProtocol, + PreSharedKey: cfg.PreSharedKey, + SupportsCustomPorts: cfg.SupportsCustomPorts, + RequireSubdomain: cfg.RequireSubdomain, + Private: cfg.Private, + MaxDialTimeout: cfg.MaxDialTimeout, + MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout, + GeoDataDir: cfg.GeoDataDir, + CrowdSecAPIURL: cfg.CrowdSecAPIURL, + CrowdSecAPIKey: cfg.CrowdSecAPIKey, + } +} diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index d7e891801..bf5067b85 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -239,6 +239,10 @@ func (m *testProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) return nil } +func (m *testProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool { + return nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } @@ -565,6 +569,7 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T proxytypes.AccountID(mapping.GetAccountId()), proxytypes.ServiceID(mapping.GetId()), nil, + mapping.GetPrivate(), ) require.NoError(t, err) diff --git a/proxy/server.go b/proxy/server.go index 6ccfa3e9a..63a0c577a 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -37,6 +37,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" grpcstatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + goproto "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/proxy/internal/accesslog" @@ -114,9 +116,28 @@ type Server struct { // The mapping worker waits on this before processing updates. routerReady chan struct{} + // inbound, when non-nil, manages per-account inbound listeners. Set by + // initPrivateInbound only when Private is true so the standalone + // proxy keeps its zero-overhead default path. + inbound *inboundManager + + // Lifecycle state โ€” populated by Start, consumed by Stop. The fields + // stay zero on a fresh Server until Start runs so direct struct + // construction (`&Server{...}`) used by tests still works. + runCancel context.CancelFunc + mgmtConn *grpc.ClientConn + runErr error + runErrCh chan struct{} + startMu sync.Mutex + started bool + stopOnce sync.Once + // Mostly used for debugging on management. startTime time.Time + // ListenAddr is the address the main TCP listener binds. Populated by + // New from Config or by ListenAndServe from its addr argument. + ListenAddr string ID string Logger *log.Logger Version string @@ -177,6 +198,14 @@ type Server struct { // in front of this proxy's cluster domain. When true, accounts cannot // create services on the bare cluster domain. RequireSubdomain bool + // Private flags this proxy as embedded in a netbird client and serving + // exclusively over the WireGuard tunnel (i.e. `netbird proxy`). Reported + // upstream as a capability so dashboards can distinguish per-peer + // clusters from centralised ones, and turns on per-account inbound + // listeners on each embedded client's netstack: every account that + // registers a service exposes :80 + :443 inside its own WG tunnel, + // scoped to that account's services only. + Private bool // MaxDialTimeout caps the per-service backend dial timeout. // When the API sends a timeout, it is clamped to this value. // When the API sends no timeout, this value is used as the default. @@ -222,12 +251,16 @@ func (s *Server) NotifyStatus(ctx context.Context, accountID types.AccountID, se status = proto.ProxyStatus_PROXY_STATUS_ACTIVE } - _, err := s.mgmtClient.SendStatusUpdate(ctx, &proto.SendStatusUpdateRequest{ + req := &proto.SendStatusUpdateRequest{ ServiceId: string(serviceID), AccountId: string(accountID), Status: status, CertificateIssued: false, - }) + } + if connected { + req.InboundListener = s.inboundListenerProto(accountID) + } + _, err := s.mgmtClient.SendStatusUpdate(ctx, req) return err } @@ -238,56 +271,68 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac AccountId: string(accountID), Status: proto.ProxyStatus_PROXY_STATUS_ACTIVE, CertificateIssued: true, + InboundListener: s.inboundListenerProto(accountID), }) return err } -func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { - s.initDefaults() - s.routerReady = make(chan struct{}) - s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) - s.portRouters = make(map[uint16]*portRouter) - s.svcPorts = make(map[types.ServiceID][]uint16) - s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) - - exporter, err := prometheus.New() - if err != nil { - return fmt.Errorf("create prometheus exporter: %w", err) +// inboundListenerProto resolves the per-account inbound listener state for +// the SendStatusUpdate payload. Returns nil when --private-inbound is off +// or the account has no live listener so management treats the field as +// absent. +func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener { + if s.inbound == nil { + return nil } - - provider := metric.NewMeterProvider(metric.WithReader(exporter)) - pkg := reflect.TypeOf(Server{}).PkgPath() - meter := provider.Meter(pkg) - - s.meter, err = proxymetrics.New(ctx, meter) - if err != nil { - return fmt.Errorf("create metrics: %w", err) + info, ok := s.inbound.ListenerInfo(accountID) + if !ok || info.TunnelIP == "" { + return nil } + return &proto.ProxyInboundListener{ + TunnelIp: info.TunnelIP, + HttpsPort: uint32(info.HTTPSPort), + HttpPort: uint32(info.HTTPPort), + } +} - mgmtConn, err := s.dialManagement() - if err != nil { +// ListenAndServe is the standalone entrypoint. It binds the listener, runs +// the proxy until ctx is cancelled or a background goroutine fails, then +// drains and stops. Library callers should prefer New + Start + Stop and +// own their own shutdown signalling. +func (s *Server) ListenAndServe(ctx context.Context, addr string) error { + s.ListenAddr = addr + if err := s.Start(ctx); err != nil { return err } - defer func() { - if err := mgmtConn.Close(); err != nil { - s.Logger.Debugf("management connection close: %v", err) - } - }() - s.mgmtClient = proto.NewProxyServiceClient(mgmtConn) + return s.waitAndStop(ctx) +} + +// Start brings the proxy up: dials management, configures TLS, binds the +// main listener, and spawns the SNI router and HTTPS server goroutines. It +// returns once the listener is bound; background errors are surfaced +// through Stop's return value. Start is not safe to call twice. +func (s *Server) Start(ctx context.Context) error { + s.startMu.Lock() + if s.started { + s.startMu.Unlock() + return errors.New("proxy already started") + } + s.started = true + s.startMu.Unlock() + + s.initLifecycleState() + if err := s.initMetrics(ctx); err != nil { + return err + } + + if err := s.initManagementClient(); err != nil { + return err + } + runCtx, runCancel := context.WithCancel(ctx) - defer runCancel() + s.runCancel = runCancel - // Initialize the netbird client, this is required to build peer connections - // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ - MgmtAddr: s.ManagementAddress, - WGPort: s.WireguardPort, - PreSharedKey: s.PreSharedKey, - }, s.Logger, s, s.mgmtClient) - s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration - - // Create health checker before the mapping worker so it can track - // management connectivity from the first stream connection. + s.initNetBirdClient() s.healthChecker = health.NewChecker(s.Logger, s.netbird) s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger)) @@ -300,34 +345,25 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return err } - // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. - s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(s.netbird), s.ForwardedProto, s.TrustedProxies, s.Logger) + s.initReverseProxy() - geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir) - if err != nil { - return fmt.Errorf("initialize geolocation: %w", err) - } - s.geoRaw = geoLookup - if geoLookup != nil { - s.geo = geoLookup + if err := s.initGeoLookup(); err != nil { + return err } - var startupOK bool + startupOK := false defer func() { if startupOK { return } if s.geoRaw != nil { - if err := s.geoRaw.Close(); err != nil { - s.Logger.Debugf("close geolocation on startup failure: %v", err) + if closeErr := s.geoRaw.Close(); closeErr != nil { + s.Logger.Debugf("close geolocation on startup failure: %v", closeErr) } } }() - // Configure the authentication middleware with session validator for OIDC group checks. s.auth = auth.NewMiddleware(s.Logger, s.mgmtClient, s.geo) - - // Configure Access logs to management server. s.accessLog = accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) s.startDebugEndpoint() @@ -336,35 +372,21 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return err } - // Build the handler chain from inside out. - handler := http.Handler(s.proxy) - handler = s.auth.Protect(handler) - handler = web.AssetHandler(handler) - handler = s.accessLog.Middleware(handler) - handler = s.meter.Middleware(handler) - handler = s.hijackTracker.Middleware(handler) + handler := s.buildHandlerChain() + s.initPrivateInbound(handler, tlsConfig) - // Start a raw TCP listener; the SNI router peeks at ClientHello - // and routes to either the HTTP handler or a TCP relay. - lc := net.ListenConfig{} - ln, err := lc.Listen(ctx, "tcp", addr) + ln, err := s.bindMainListener(ctx) if err != nil { - return fmt.Errorf("listen on %s: %w", addr, err) + return err } - if s.ProxyProtocol { - ln = s.wrapProxyProtocol(ln) - } - s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid - // Set up the SNI router for TCP/HTTP multiplexing on the main port. s.mainRouter = nbtcp.NewRouter(s.Logger, s.resolveDialFunc, ln.Addr()) s.mainRouter.SetObserver(s.meter) s.mainRouter.SetAccessLogger(s.accessLog) close(s.routerReady) - // The HTTP server uses the chanListener fed by the SNI router. s.https = &http.Server{ - Addr: addr, + Addr: s.ListenAddr, Handler: handler, TLSConfig: tlsConfig, ReadHeaderTimeout: httpReadHeaderTimeout, @@ -374,35 +396,201 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { startupOK = true - httpsErr := make(chan error, 1) go func() { s.Logger.Debug("starting HTTPS server on SNI router HTTP channel") - httpsErr <- s.https.ServeTLS(s.mainRouter.HTTPListener(), "", "") + if serveErr := s.https.ServeTLS(s.mainRouter.HTTPListener(), "", ""); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { + s.recordRunErr(fmt.Errorf("https server: %w", serveErr)) + } }() - routerErr := make(chan error, 1) go func() { - s.Logger.Debugf("starting SNI router on %s", addr) - routerErr <- s.mainRouter.Serve(runCtx, ln) + s.Logger.Debugf("starting SNI router on %s", s.ListenAddr) + if serveErr := s.mainRouter.Serve(runCtx, ln); serveErr != nil { + s.recordRunErr(fmt.Errorf("SNI router: %w", serveErr)) + } }() + return nil +} + +// Stop drains in-flight connections, shuts down all background services, +// and releases resources. Idempotent; calling it before Start is a no-op. +// Returns the first fatal error reported by a background goroutine, if +// any. The provided ctx bounds the total wait time; once it is cancelled +// Stop returns even if drain is still in flight. +func (s *Server) Stop(ctx context.Context) error { + s.stopOnce.Do(func() { + s.startMu.Lock() + started := s.started + s.startMu.Unlock() + if !started { + return + } + + done := make(chan struct{}) + go func() { + defer close(done) + s.gracefulShutdown() + if s.runCancel != nil { + s.runCancel() + } + if s.mgmtConn != nil { + if err := s.mgmtConn.Close(); err != nil { + s.Logger.Debugf("management connection close: %v", err) + } + } + }() + + select { + case <-done: + case <-ctx.Done(): + s.Logger.Warnf("proxy stop deadline exceeded: %v", ctx.Err()) + } + }) + + s.startMu.Lock() + defer s.startMu.Unlock() + return s.runErr +} + +// waitAndStop blocks until ctx is cancelled or a background goroutine +// reports a fatal error, then drains and stops. Used by ListenAndServe. +func (s *Server) waitAndStop(ctx context.Context) error { select { - case err := <-httpsErr: - s.shutdownServices() - if !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("https server: %w", err) - } - return nil - case err := <-routerErr: - s.shutdownServices() - if err != nil { - return fmt.Errorf("SNI router: %w", err) - } - return nil case <-ctx.Done(): - s.gracefulShutdown() - return nil + case <-s.runErrCh: } + stopCtx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout+shutdownServiceTimeout) + defer cancel() + return s.Stop(stopCtx) +} + +// recordRunErr stores the first fatal background error and signals +// waitAndStop. Subsequent errors are logged at debug level so the first +// cause is preserved. +func (s *Server) recordRunErr(err error) { + s.startMu.Lock() + defer s.startMu.Unlock() + if s.runErr != nil { + s.Logger.Debugf("background error after first failure: %v", err) + return + } + s.runErr = err + if s.runErrCh != nil { + close(s.runErrCh) + } +} + +// initLifecycleState seeds the maps and channels Start needs to wire up +// background goroutines. Called once at the top of Start. +func (s *Server) initLifecycleState() { + s.initDefaults() + s.routerReady = make(chan struct{}) + s.runErrCh = make(chan struct{}) + s.udpRelays = make(map[types.ServiceID]*udprelay.Relay) + s.portRouters = make(map[uint16]*portRouter) + s.svcPorts = make(map[types.ServiceID][]uint16) + s.lastMappings = make(map[types.ServiceID]*proto.ProxyMapping) +} + +// initMetrics builds the prometheus exporter and meter bundle. +func (s *Server) initMetrics(ctx context.Context) error { + exporter, err := prometheus.New() + if err != nil { + return fmt.Errorf("create prometheus exporter: %w", err) + } + provider := metric.NewMeterProvider(metric.WithReader(exporter)) + pkg := reflect.TypeOf(Server{}).PkgPath() + meter := provider.Meter(pkg) + s.meter, err = proxymetrics.New(ctx, meter) + if err != nil { + return fmt.Errorf("create metrics: %w", err) + } + return nil +} + +// initManagementClient dials management and stashes the connection so +// Stop can close it deterministically. +func (s *Server) initManagementClient() error { + conn, err := s.dialManagement() + if err != nil { + return err + } + s.mgmtConn = conn + s.mgmtClient = proto.NewProxyServiceClient(conn) + return nil +} + +// initNetBirdClient builds the multi-tenant embedded NetBird client used +// for outbound RoundTripping and (when --private-inbound is on) per-account +// inbound listeners. +func (s *Server) initNetBirdClient() { + s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ + MgmtAddr: s.ManagementAddress, + WGPort: s.WireguardPort, + PreSharedKey: s.PreSharedKey, + // On --private the embedded client serves per-account inbound + // listeners and must apply management's ACL: keep BlockInbound off + // so the engine creates the ACL manager. On the standalone proxy + // the embedded client never accepts inbound, so block. + BlockInbound: !s.Private, + }, s.Logger, s, s.mgmtClient) + s.netbird.OnAddPeer = s.meter.RecordAddPeerDuration +} + +// initReverseProxy builds the meter-instrumented reverse proxy. MultiTransport +// routes targets opted into direct_upstream through the host's network stack +// (stdlib transport); everything else falls through to the embedded NetBird +// client. The split is needed so direct_upstream targets resolve DNS via the +// proxy host's resolver instead of the tunnel's DNS. +func (s *Server) initReverseProxy() { + upstreamRT := roundtrip.NewMultiTransport(s.netbird, s.Logger) + s.proxy = proxy.NewReverseProxy(s.meter.RoundTripper(upstreamRT), s.ForwardedProto, s.TrustedProxies, s.Logger) +} + +// initGeoLookup configures the GeoLite2 lookup used for country-based +// access restrictions and access-log enrichment. +func (s *Server) initGeoLookup() error { + geoLookup, err := geolocation.NewLookup(s.Logger, s.GeoDataDir) + if err != nil { + return fmt.Errorf("initialize geolocation: %w", err) + } + s.geoRaw = geoLookup + if geoLookup != nil { + s.geo = geoLookup + } + return nil +} + +// buildHandlerChain wires the request middlewares from inside out. +func (s *Server) buildHandlerChain() http.Handler { + handler := http.Handler(s.proxy) + handler = s.auth.Protect(handler) + handler = web.AssetHandler(handler) + handler = s.accessLog.Middleware(handler) + handler = s.meter.Middleware(handler) + return s.hijackTracker.Middleware(handler) +} + +// bindMainListener binds the main TCP listener and wraps it with PROXY +// protocol when configured. +func (s *Server) bindMainListener(ctx context.Context) (net.Listener, error) { + lc := net.ListenConfig{} + ln, err := lc.Listen(ctx, "tcp", s.ListenAddr) + if err != nil { + return nil, fmt.Errorf("listen on %s: %w", s.ListenAddr, err) + } + if s.ProxyProtocol { + ln = s.wrapProxyProtocol(ln) + } + s.mainPort = uint16(ln.Addr().(*net.TCPAddr).Port) //nolint:gosec // port from OS is always valid + s.Logger.WithFields(log.Fields{ + "requested_addr": s.ListenAddr, + "bound_addr": ln.Addr().String(), + "private": s.Private, + "proxy_protocol": s.ProxyProtocol, + }).Info("proxy main listener bound") + return ln, nil } // initDefaults sets fallback values for optional Server fields. @@ -434,6 +622,9 @@ func (s *Server) startDebugEndpoint() { if s.acme != nil { debugHandler.SetCertStatus(s.acme) } + if s.inbound != nil { + debugHandler.SetInboundProvider(inboundDebugAdapter{mgr: s.inbound}) + } s.debug = &http.Server{ Addr: debugAddr, Handler: debugHandler, @@ -447,16 +638,18 @@ func (s *Server) startDebugEndpoint() { }() } -// startHealthServer launches the health probe and metrics server. +// startHealthServer launches the health probe and metrics server. Empty +// HealthAddress disables the probe entirely (intended for library callers +// that want to manage their own health surface). func (s *Server) startHealthServer() error { - healthAddr := s.HealthAddress - if healthAddr == "" { - healthAddr = defaultHealthAddr + if s.HealthAddress == "" { + s.Logger.Debug("health probe disabled (empty HealthAddress)") + return nil } - s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true})) - healthListener, err := net.Listen("tcp", healthAddr) + s.healthServer = health.NewServer(s.HealthAddress, s.healthChecker, s.Logger, promhttp.HandlerFor(prometheus2.DefaultGatherer, promhttp.HandlerOpts{EnableOpenMetrics: true})) + healthListener, err := net.Listen("tcp", s.HealthAddress) if err != nil { - return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) + return fmt.Errorf("health probe server listen on %s: %w", s.HealthAddress, err) } go func() { if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -507,8 +700,9 @@ func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxypr } const ( - defaultHealthAddr = "localhost:8080" - defaultDebugAddr = "localhost:8444" + // defaultDebugAddr is the localhost-bound fallback for the debug endpoint + // when DebugEndpointAddress is empty. + defaultDebugAddr = "localhost:8444" // proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol // header after accepting a connection. @@ -661,8 +855,10 @@ func (s *Server) gracefulShutdown() { defer drainCancel() s.Logger.Info("draining in-flight connections") - if err := s.https.Shutdown(drainCtx); err != nil { - s.Logger.Warnf("https server drain: %v", err) + if s.https != nil { + if err := s.https.Shutdown(drainCtx); err != nil { + s.Logger.Warnf("https server drain: %v", err) + } } // Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle. @@ -809,6 +1005,18 @@ func (s *Server) resolveDialFunc(accountID types.AccountID) (types.DialContextFu return client.DialContext, nil } +// initPrivateInbound wires per-account inbound listeners when --private +// is set. When the flag is off this is a no-op and the standalone proxy keeps +// its byte-for-byte previous behaviour. +func (s *Server) initPrivateInbound(handler http.Handler, tlsConfig *tls.Config) { + if !s.Private { + return + } + s.inbound = newInboundManager(s.Logger, handler, tlsConfig) + s.netbird.SetClientLifecycle(s.inbound.onClientReady, s.inbound.onClientStop) + s.Logger.Info("private inbound listeners enabled (per-account :80 + :443)") +} + // notifyError reports a resource error back to management so it can be // surfaced to the user (e.g. port bind failure, dialer resolution error). func (s *Server) notifyError(ctx context.Context, mapping *proto.ProxyMapping, err error) { @@ -942,7 +1150,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } // syncSupported tracks whether management supports SyncMappings. - // Starts true; set to false on first Unimplemented error. + // Starts true; set to false on the first Unimplemented error so + // subsequent retries skip straight to GetMappingUpdate. syncSupported := true initialSyncDone := false @@ -992,10 +1201,15 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr func (s *Server) proxyCapabilities() *proto.ProxyCapabilities { supportsCrowdSec := s.crowdsecRegistry.Available() + privateCapability := s.Private + // Always true: this build enforces ProxyMapping.private via the auth middleware. + supportsPrivateService := true return &proto.ProxyCapabilities{ - SupportsCustomPorts: &s.SupportsCustomPorts, - RequireSubdomain: &s.RequireSubdomain, - SupportsCrowdsec: &supportsCrowdSec, + SupportsCustomPorts: &s.SupportsCustomPorts, + RequireSubdomain: &s.RequireSubdomain, + SupportsCrowdsec: &supportsCrowdSec, + Private: &privateCapability, + SupportsPrivateService: &supportsPrivateService, } } @@ -1027,7 +1241,6 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC return fmt.Errorf("create sync stream: %w", err) } - // Send init message. if err := stream.Send(&proto.SyncMappingsRequest{ Msg: &proto.SyncMappingsRequest_Init{ Init: &proto.SyncMappingsInit{ @@ -1058,6 +1271,10 @@ func isSyncUnimplemented(err error) bool { return ok && st.Code() == codes.Unimplemented } +// handleSyncMappingsStream consumes batches from a bidirectional SyncMappings +// stream, sending an ack after each batch is fully processed. Management waits +// for the ack before sending the next batch, providing application-level +// back-pressure. func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.ProxyService_SyncMappingsClient, initialSyncDone *bool, connectTime time.Time) error { select { case <-s.routerReady: @@ -1095,39 +1312,10 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox } } -func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error { - select { - case <-s.routerReady: - case <-ctx.Done(): - return ctx.Err() - } - - tracker := s.newSnapshotTracker(initialSyncDone, connectTime) - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - msg, err := mappingClient.Recv() - switch { - case errors.Is(err, io.EOF): - return nil - case err != nil: - return fmt.Errorf("receive msg: %w", err) - } - - batchStart := time.Now() - s.Logger.Debug("Received mapping update, starting processing") - s.processMappings(ctx, msg.GetMapping()) - s.Logger.Debug("Processing mapping update completed") - tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart) - } - } -} - // snapshotTracker accumulates service IDs during the initial snapshot and -// finalises sync state when the complete flag arrives. +// finalises sync state when the complete flag arrives. Used by both +// handleMappingStream and handleSyncMappingsStream so metric emission and +// reconciliation behave identically on either RPC. type snapshotTracker struct { done *bool connectTime time.Time @@ -1171,6 +1359,37 @@ func (t *snapshotTracker) recordBatch(ctx context.Context, s *Server, mappings [ s.Logger.Info("Initial mapping sync complete") } +func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool, connectTime time.Time) error { + select { + case <-s.routerReady: + case <-ctx.Done(): + return ctx.Err() + } + + tracker := s.newSnapshotTracker(initialSyncDone, connectTime) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + msg, err := mappingClient.Recv() + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return fmt.Errorf("receive msg: %w", err) + } + + batchStart := time.Now() + s.Logger.Debug("Received mapping update, starting processing") + s.processMappings(ctx, msg.GetMapping()) + s.Logger.Debug("Processing mapping update completed") + tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart) + } + } +} + // reconcileSnapshot removes local mappings that are absent from the snapshot. // This ensures services deleted while the proxy was disconnected get cleaned up. func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) { @@ -1192,17 +1411,58 @@ func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.Se } } -func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { - s.ensurePeers(ctx, mappings) +// mappingJSONMarshal dumps mappings on one line with zero-value fields visible for debug logs. +var mappingJSONMarshal = protojson.MarshalOptions{ + Multiline: false, + EmitUnpopulated: true, + UseProtoNames: true, +} +// redactMappingForLog returns a deep copy of the mapping with sensitive fields +// (auth_token, header-auth hashed values, custom upstream headers) replaced so +// debug logs never carry credentials. +func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping { + const placeholder = "[REDACTED]" + c := goproto.Clone(m).(*proto.ProxyMapping) + if c.GetAuthToken() != "" { + c.AuthToken = placeholder + } + if c.Auth != nil { + for _, h := range c.Auth.GetHeaderAuths() { + if h.GetHashedValue() != "" { + h.HashedValue = placeholder + } + } + } + for _, p := range c.GetPath() { + opts := p.GetOptions() + if opts == nil || len(opts.CustomHeaders) == 0 { + continue + } + redacted := make(map[string]string, len(opts.CustomHeaders)) + for k := range opts.CustomHeaders { + redacted[k] = placeholder + } + opts.CustomHeaders = redacted + } + return c +} + +func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) { + debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel) for _, mapping := range mappings { - s.Logger.WithFields(log.Fields{ - "type": mapping.GetType(), - "domain": mapping.GetDomain(), - "mode": mapping.GetMode(), - "port": mapping.GetListenPort(), - "id": mapping.GetId(), - }).Debug("Processing mapping update") + if debug { + raw, err := mappingJSONMarshal.Marshal(redactMappingForLog(mapping)) + if err != nil { + raw = []byte(fmt.Sprintf("", err)) + } + s.Logger.WithFields(log.Fields{ + "type": mapping.GetType(), + "domain": mapping.GetDomain(), + "id": mapping.GetId(), + "mapping": string(raw), + }).Debug("Processing mapping update") + } switch mapping.GetType() { case proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED: if err := s.addMapping(ctx, mapping); err != nil { @@ -1228,60 +1488,6 @@ func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMap } } -// ensurePeers pre-creates NetBird peers for all unique accounts referenced by -// CREATED mappings. Peers for different accounts are created concurrently, -// which avoids serializing Nร—100ms gRPC round-trips during large initial syncs. -func (s *Server) ensurePeers(ctx context.Context, mappings []*proto.ProxyMapping) { - // Collect one representative mapping per account that needs a new peer. - type peerReq struct { - accountID types.AccountID - svcKey roundtrip.ServiceKey - authToken string - svcID types.ServiceID - } - seen := make(map[types.AccountID]struct{}) - var reqs []peerReq - for _, m := range mappings { - if m.GetType() != proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED { - continue - } - accountID := types.AccountID(m.GetAccountId()) - if _, ok := seen[accountID]; ok { - continue - } - seen[accountID] = struct{}{} - if s.netbird.HasClient(accountID) { - continue - } - reqs = append(reqs, peerReq{ - accountID: accountID, - svcKey: s.serviceKeyForMapping(m), - authToken: m.GetAuthToken(), - svcID: types.ServiceID(m.GetId()), - }) - } - - if len(reqs) <= 1 { - return - } - - var wg sync.WaitGroup - wg.Add(len(reqs)) - for _, r := range reqs { - go func() { - defer wg.Done() - if err := s.netbird.AddPeer(ctx, r.accountID, r.svcKey, r.authToken, r.svcID); err != nil { - s.Logger.WithFields(log.Fields{ - "account_id": r.accountID, - "service_id": r.svcID, - "error": err, - }).Warn("failed to pre-create peer for account") - } - }() - } - wg.Wait() -} - // addMapping registers a service mapping and starts the appropriate relay or routes. func (s *Server) addMapping(ctx context.Context, mapping *proto.ProxyMapping) error { accountID := types.AccountID(mapping.GetAccountId()) @@ -1353,12 +1559,16 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi if s.acme != nil { wildcardHit = s.acme.AddDomain(d, accountID, svcID) } - s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), nbtcp.Route{ + httpRoute := nbtcp.Route{ Type: nbtcp.RouteHTTP, AccountID: accountID, ServiceID: svcID, Domain: mapping.GetDomain(), - }) + } + s.mainRouter.AddRoute(nbtcp.SNIHost(mapping.GetDomain()), httpRoute) + if s.inbound != nil { + s.inbound.AddRoute(accountID, nbtcp.SNIHost(mapping.GetDomain()), httpRoute) + } if err := s.updateMapping(ctx, mapping); err != nil { return fmt.Errorf("update mapping for domain %q: %w", d, err) } @@ -1718,7 +1928,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) s.warnIfGeoUnavailable(mapping.GetDomain(), mapping.GetAccessRestrictions()) maxSessionAge := time.Duration(mapping.GetAuth().GetMaxSessionAgeSeconds()) * time.Second - if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions); err != nil { + if err := s.auth.AddDomain(mapping.GetDomain(), schemes, mapping.GetAuth().GetSessionKey(), maxSessionAge, accountID, svcID, ipRestrictions, mapping.GetPrivate()); err != nil { return fmt.Errorf("auth setup for domain %s: %w", mapping.GetDomain(), err) } m := s.protoToMapping(ctx, mapping) @@ -1774,6 +1984,9 @@ func (s *Server) cleanupMappingRoutes(mapping *proto.ProxyMapping) { } // Remove SNI route from the main router (covers both HTTP and main-port TLS). s.mainRouter.RemoveRoute(nbtcp.SNIHost(host), svcID) + if s.inbound != nil { + s.inbound.RemoveRoute(types.AccountID(mapping.GetAccountId()), nbtcp.SNIHost(host), svcID) + } } // Extract and delete tracked custom-port entries atomically. @@ -1861,6 +2074,7 @@ func (s *Server) protoToMapping(ctx context.Context, mapping *proto.ProxyMapping if d := opts.GetRequestTimeout(); d != nil { pt.RequestTimeout = d.AsDuration() } + pt.DirectUpstream = opts.GetDirectUpstream() } pt.RequestTimeout = s.clampDialTimeout(pt.RequestTimeout) paths[pathMapping.GetPath()] = pt diff --git a/proxy/server_test.go b/proxy/server_test.go index b4fb4f8ba..10d38f250 100644 --- a/proxy/server_test.go +++ b/proxy/server_test.go @@ -1,9 +1,17 @@ package proxy import ( + "context" + "errors" + "io" "testing" + "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/proto" ) func TestDebugEndpointDisabledByDefault(t *testing.T) { @@ -46,3 +54,151 @@ func TestDebugEndpointAddr(t *testing.T) { }) } } + +// quietLifecycleLogger keeps lifecycle tests from spamming the test output. +func quietLifecycleLogger() *log.Logger { + l := log.New() + l.SetOutput(io.Discard) + l.SetLevel(log.PanicLevel) + return l +} + +func TestStopBeforeStartIsNoOp(t *testing.T) { + srv := New(Config{Logger: quietLifecycleLogger()}) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := srv.Stop(ctx) + assert.NoError(t, err, "Stop on an unstarted server must succeed without error") + + err = srv.Stop(ctx) + assert.NoError(t, err, "Stop must remain idempotent across repeated calls") +} + +func TestStartFailsWithoutManagement(t *testing.T) { + srv := New(Config{ + Logger: quietLifecycleLogger(), + ListenAddr: "127.0.0.1:0", + ManagementAddress: "://broken-url", + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := srv.Start(ctx) + require.Error(t, err, "Start must surface management dial failures") + + assert.True(t, srv.started, "started flag is set before any dial attempt so a second Start fails fast") + + err = srv.Start(ctx) + require.Error(t, err, "second Start must reject") + assert.Contains(t, err.Error(), "already started", "error must explain why the call was rejected") +} + +func TestStopIsIdempotent(t *testing.T) { + srv := &Server{ + Logger: quietLifecycleLogger(), + started: true, + runErrCh: make(chan struct{}), + runCancel: func() {}, + } + srv.recordRunErr(errors.New("synthetic")) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err := srv.Stop(ctx) + require.Error(t, err, "Stop must surface the recorded background error") + assert.Contains(t, err.Error(), "synthetic", "error must round-trip recordRunErr's value") + + err = srv.Stop(ctx) + require.Error(t, err, "second Stop must still report the same error") + assert.Contains(t, err.Error(), "synthetic", "idempotent Stop must return the cached error") +} + +func TestRecordRunErrPreservesFirstFailure(t *testing.T) { + srv := &Server{ + Logger: quietLifecycleLogger(), + runErrCh: make(chan struct{}), + } + + srv.recordRunErr(errors.New("first")) + srv.recordRunErr(errors.New("second")) + + require.Error(t, srv.runErr, "first failure must be retained") + assert.Contains(t, srv.runErr.Error(), "first", "second call must not overwrite the cached error") + + select { + case <-srv.runErrCh: + default: + t.Fatal("recordRunErr must close runErrCh so waitAndStop unblocks") + } +} + +func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) { + srv := New(Config{Logger: quietLifecycleLogger()}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := srv.Stop(ctx) + assert.NoError(t, err, "Stop on an unstarted server should not block on the cancelled ctx") +} + +func TestRedactMappingForLog_ScrubsSensitiveFields(t *testing.T) { + original := &proto.ProxyMapping{ + Id: "svc-1", + Domain: "example.com", + AuthToken: "super-secret-token", + Auth: &proto.Authentication{ + SessionKey: "pubkey-not-secret", + HeaderAuths: []*proto.HeaderAuth{ + {Header: "Authorization", HashedValue: "argon2-hash-1"}, + {Header: "X-Api-Key", HashedValue: "argon2-hash-2"}, + }, + }, + Path: []*proto.PathMapping{ + { + Path: "/api", + Target: "10.0.0.1:8080", + Options: &proto.PathTargetOptions{ + CustomHeaders: map[string]string{ + "Authorization": "Bearer upstream-token", + "X-Tenant": "acme", + }, + }, + }, + }, + } + + redacted := redactMappingForLog(original) + + assert.Equal(t, "super-secret-token", original.AuthToken, "original must not be mutated") + assert.Equal(t, "argon2-hash-1", original.Auth.HeaderAuths[0].HashedValue, "original header hash must not be mutated") + assert.Equal(t, "Bearer upstream-token", original.Path[0].Options.CustomHeaders["Authorization"], "original custom header must not be mutated") + + assert.Equal(t, "[REDACTED]", redacted.AuthToken, "auth_token must be redacted") + require.Len(t, redacted.Auth.HeaderAuths, 2, "header auths must be preserved in count") + assert.Equal(t, "Authorization", redacted.Auth.HeaderAuths[0].Header, "header name must be preserved") + assert.Equal(t, "[REDACTED]", redacted.Auth.HeaderAuths[0].HashedValue, "hashed_value must be redacted") + assert.Equal(t, "[REDACTED]", redacted.Auth.HeaderAuths[1].HashedValue, "hashed_value must be redacted for every header auth") + assert.Equal(t, "pubkey-not-secret", redacted.Auth.SessionKey, "session_key (public) must be preserved") + + headers := redacted.Path[0].Options.CustomHeaders + require.Len(t, headers, 2, "custom header keys must be preserved") + assert.Equal(t, "[REDACTED]", headers["Authorization"], "custom header values must be redacted") + assert.Equal(t, "[REDACTED]", headers["X-Tenant"], "every custom header value must be redacted") + + assert.Equal(t, "svc-1", redacted.Id, "non-sensitive fields must round-trip") + assert.Equal(t, "example.com", redacted.Domain, "non-sensitive fields must round-trip") +} + +func TestRedactMappingForLog_HandlesEmptyOrNilFields(t *testing.T) { + empty := &proto.ProxyMapping{Id: "svc-empty"} + redacted := redactMappingForLog(empty) + + assert.Equal(t, "", redacted.AuthToken, "empty auth_token must remain empty (no placeholder)") + assert.Nil(t, redacted.Auth, "nil Auth must remain nil") + assert.Empty(t, redacted.Path, "empty Path must remain empty") +} diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index f0cb4d2d1..43312b9e6 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -143,6 +143,10 @@ type Client struct { // ReverseProxyDomains NetBird reverse proxy domains APIs ReverseProxyDomains *ReverseProxyDomainsAPI + + // ReverseProxyTokens account-scoped proxy access tokens used to register + // self-hosted (bring-your-own-proxy) `netbird proxy` instances. + ReverseProxyTokens *ReverseProxyTokensAPI } // New initialize new Client instance using PAT token @@ -204,6 +208,7 @@ func (c *Client) initialize() { c.ReverseProxyServices = &ReverseProxyServicesAPI{c} c.ReverseProxyClusters = &ReverseProxyClustersAPI{c} c.ReverseProxyDomains = &ReverseProxyDomainsAPI{c} + c.ReverseProxyTokens = &ReverseProxyTokensAPI{c} } // NewRequest creates and executes new management API request diff --git a/shared/management/client/rest/reverse_proxy_clusters.go b/shared/management/client/rest/reverse_proxy_clusters.go index b55cd35a3..249833b01 100644 --- a/shared/management/client/rest/reverse_proxy_clusters.go +++ b/shared/management/client/rest/reverse_proxy_clusters.go @@ -2,6 +2,7 @@ package rest import ( "context" + "net/url" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -11,7 +12,10 @@ type ReverseProxyClustersAPI struct { c *Client } -// List lists all available proxy clusters +// List lists all available proxy clusters. Each cluster is enriched with the +// capability flags reported by its connected proxies (supports_custom_ports, +// supports_crowdsec, private, etc.), so callers can render UX gates without +// a follow-up round-trip. func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster, error) { resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/clusters", nil, nil) if err != nil { @@ -23,3 +27,18 @@ func (a *ReverseProxyClustersAPI) List(ctx context.Context) ([]api.ProxyCluster, ret, err := parseResponse[[]api.ProxyCluster](resp) return ret, err } + +// Delete removes every self-hosted (BYOP) proxy registration for the given +// cluster address owned by the calling account. Shared clusters operated by +// NetBird cannot be deleted via this endpoint; the server returns 404 / 400 +// for cluster addresses the account does not own. +func (a *ReverseProxyClustersAPI) Delete(ctx context.Context, clusterAddress string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/clusters/"+url.PathEscape(clusterAddress), nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} diff --git a/shared/management/client/rest/reverse_proxy_clusters_test.go b/shared/management/client/rest/reverse_proxy_clusters_test.go new file mode 100644 index 000000000..2d9f6f7bb --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_clusters_test.go @@ -0,0 +1,90 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +func boolPtr(b bool) *bool { return &b } + +var testCluster = api.ProxyCluster{ + Id: "cluster-1", + Address: "proxy.netbird.local", + Type: "shared", + Online: true, + ConnectedProxies: 2, + SupportsCustomPorts: boolPtr(true), + RequireSubdomain: boolPtr(false), + SupportsCrowdsec: boolPtr(false), + Private: boolPtr(true), +} + +func TestReverseProxyClusters_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/clusters", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method, "List must use GET") + retBytes, _ := json.Marshal([]api.ProxyCluster{testCluster}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyClusters.List(context.Background()) + require.NoError(t, err) + require.Len(t, ret, 1) + assert.Equal(t, testCluster.Id, ret[0].Id) + assert.Equal(t, testCluster.Address, ret[0].Address) + require.NotNil(t, ret[0].Private, "private capability must round-trip through the client") + assert.True(t, *ret[0].Private, "private capability must reflect the server value") + }) +} + +func TestReverseProxyClusters_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/clusters", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 500}) + w.WriteHeader(500) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyClusters.List(context.Background()) + assert.Error(t, err) + assert.Empty(t, ret) + }) +} + +func TestReverseProxyClusters_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + // PathEscape on "proxy.netbird.local" leaves it intact; the route mux + // matches the unescaped form. Sanity-check both the method and that + // path-escaping doesn't double-encode the dotted address. + mux.HandleFunc("/api/reverse-proxies/clusters/proxy.netbird.local", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method, "Delete must use DELETE") + w.WriteHeader(200) + }) + err := c.ReverseProxyClusters.Delete(context.Background(), "proxy.netbird.local") + require.NoError(t, err) + }) +} + +func TestReverseProxyClusters_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/clusters/proxy.netbird.local", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.ReverseProxyClusters.Delete(context.Background(), "proxy.netbird.local") + assert.Error(t, err) + }) +} diff --git a/shared/management/client/rest/reverse_proxy_services_test.go b/shared/management/client/rest/reverse_proxy_services_test.go index 164563e97..1a93472db 100644 --- a/shared/management/client/rest/reverse_proxy_services_test.go +++ b/shared/management/client/rest/reverse_proxy_services_test.go @@ -116,8 +116,8 @@ func TestReverseProxyServices_Create_200(t *testing.T) { Name: "test-service", Domain: "test.example.com", Enabled: true, - Auth: api.ServiceAuthConfig{}, - Targets: []api.ServiceTarget{testServiceTarget}, + Auth: &api.ServiceAuthConfig{}, + Targets: &[]api.ServiceTarget{testServiceTarget}, }) require.NoError(t, err) assert.Equal(t, testService.Id, ret.Id) @@ -136,8 +136,8 @@ func TestReverseProxyServices_Create_Err(t *testing.T) { Name: "test-service", Domain: "test.example.com", Enabled: true, - Auth: api.ServiceAuthConfig{}, - Targets: []api.ServiceTarget{testServiceTarget}, + Auth: &api.ServiceAuthConfig{}, + Targets: &[]api.ServiceTarget{testServiceTarget}, }) assert.Error(t, err) assert.Equal(t, "No", err.Error()) @@ -154,8 +154,9 @@ func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) { var req api.ServiceRequest require.NoError(t, json.Unmarshal(reqBytes, &req)) - require.Len(t, req.Targets, 1) - target := req.Targets[0] + require.NotNil(t, req.Targets, "targets must be set on the request") + require.Len(t, *req.Targets, 1) + target := (*req.Targets)[0] require.NotNil(t, target.Options, "options should be present") opts := target.Options require.NotNil(t, opts.SkipTlsVerify, "skip_tls_verify should be present") @@ -177,8 +178,8 @@ func TestReverseProxyServices_Create_WithPerTargetOptions(t *testing.T) { Name: "test-service", Domain: "test.example.com", Enabled: true, - Auth: api.ServiceAuthConfig{}, - Targets: []api.ServiceTarget{ + Auth: &api.ServiceAuthConfig{}, + Targets: &[]api.ServiceTarget{ { TargetId: "peer-123", TargetType: "peer", @@ -216,8 +217,8 @@ func TestReverseProxyServices_Update_200(t *testing.T) { Name: "updated-service", Domain: "test.example.com", Enabled: true, - Auth: api.ServiceAuthConfig{}, - Targets: []api.ServiceTarget{testServiceTarget}, + Auth: &api.ServiceAuthConfig{}, + Targets: &[]api.ServiceTarget{testServiceTarget}, }) require.NoError(t, err) assert.Equal(t, testService.Id, ret.Id) @@ -236,8 +237,8 @@ func TestReverseProxyServices_Update_Err(t *testing.T) { Name: "updated-service", Domain: "test.example.com", Enabled: true, - Auth: api.ServiceAuthConfig{}, - Targets: []api.ServiceTarget{testServiceTarget}, + Auth: &api.ServiceAuthConfig{}, + Targets: &[]api.ServiceTarget{testServiceTarget}, }) assert.Error(t, err) assert.Equal(t, "No", err.Error()) diff --git a/shared/management/client/rest/reverse_proxy_tokens.go b/shared/management/client/rest/reverse_proxy_tokens.go new file mode 100644 index 000000000..de59f3176 --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_tokens.go @@ -0,0 +1,72 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + "net/url" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// ReverseProxyTokensAPI exposes the account-scoped proxy access tokens that +// self-hosted (bring-your-own-proxy) deployments use to register a +// `netbird proxy` instance with management. Tokens are bound to the +// calling account; revoking a token disconnects every proxy that +// authenticated with it. +type ReverseProxyTokensAPI struct { + c *Client +} + +// List returns every proxy token the calling account has minted, including +// already-revoked entries. The plain token is never returned โ€” only the +// metadata (id, name, created_at, last_used, revoked). +func (a *ReverseProxyTokensAPI) List(ctx context.Context) ([]api.ProxyToken, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/reverse-proxies/proxy-tokens", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.ProxyToken](resp) + return ret, err +} + +// Create mints a fresh account-scoped proxy token. The returned +// ProxyTokenCreated.PlainToken is shown only once โ€” callers must persist +// it immediately. Subsequent reads will only expose the token metadata, +// not the secret material. +func (a *ReverseProxyTokensAPI) Create(ctx context.Context, request api.ProxyTokenRequest) (*api.ProxyTokenCreated, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/reverse-proxies/proxy-tokens", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.ProxyTokenCreated](resp) + if err != nil { + return nil, err + } + return &ret, nil +} + +// Delete revokes a previously-issued proxy token by ID. Revoked tokens +// remain in List output (with revoked=true) so operators can audit which +// credentials existed; the plain secret can no longer authenticate any +// new proxy registration. +func (a *ReverseProxyTokensAPI) Delete(ctx context.Context, tokenID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/reverse-proxies/proxy-tokens/"+url.PathEscape(tokenID), nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + return nil +} diff --git a/shared/management/client/rest/reverse_proxy_tokens_test.go b/shared/management/client/rest/reverse_proxy_tokens_test.go new file mode 100644 index 000000000..a3f5e014f --- /dev/null +++ b/shared/management/client/rest/reverse_proxy_tokens_test.go @@ -0,0 +1,131 @@ +//go:build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +func intPtr(v int) *int { return &v } + +var testProxyToken = api.ProxyToken{ + Id: "tok-1", + Name: "ci-runner", + CreatedAt: time.Date(2026, 5, 21, 9, 0, 0, 0, time.UTC), + Revoked: false, +} + +var testProxyTokenCreated = api.ProxyTokenCreated{ + Id: "tok-1", + Name: "ci-runner", + CreatedAt: time.Date(2026, 5, 21, 9, 0, 0, 0, time.UTC), + PlainToken: "nbproxy_abcdef0123456789", + Revoked: false, +} + +func TestReverseProxyTokens_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method, "List must use GET") + retBytes, _ := json.Marshal([]api.ProxyToken{testProxyToken}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyTokens.List(context.Background()) + require.NoError(t, err) + require.Len(t, ret, 1) + assert.Equal(t, testProxyToken.Id, ret[0].Id) + assert.Equal(t, testProxyToken.Name, ret[0].Name) + }) +} + +func TestReverseProxyTokens_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 500}) + w.WriteHeader(500) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyTokens.List(context.Background()) + assert.Error(t, err) + assert.Empty(t, ret) + }) +} + +func TestReverseProxyTokens_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method, "Create must use POST") + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.ProxyTokenRequest + require.NoError(t, json.Unmarshal(body, &req), "server must receive a valid ProxyTokenRequest body") + assert.Equal(t, "ci-runner", req.Name, "name must round-trip through the client") + require.NotNil(t, req.ExpiresIn, "expires_in must be sent when provided") + assert.Equal(t, 3600, *req.ExpiresIn, "expires_in value must round-trip unchanged") + + retBytes, _ := json.Marshal(testProxyTokenCreated) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyTokens.Create(context.Background(), api.ProxyTokenRequest{ + Name: "ci-runner", + ExpiresIn: intPtr(3600), + }) + require.NoError(t, err) + assert.Equal(t, testProxyTokenCreated.Id, ret.Id) + assert.Equal(t, testProxyTokenCreated.PlainToken, ret.PlainToken, + "PlainToken must be returned to the caller โ€” it's the one-shot secret") + }) +} + +func TestReverseProxyTokens_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Bad", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.ReverseProxyTokens.Create(context.Background(), api.ProxyTokenRequest{Name: ""}) + assert.Error(t, err) + assert.Nil(t, ret) + }) +} + +func TestReverseProxyTokens_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens/tok-1", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method, "Delete must use DELETE") + w.WriteHeader(200) + }) + err := c.ReverseProxyTokens.Delete(context.Background(), "tok-1") + require.NoError(t, err) + }) +} + +func TestReverseProxyTokens_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/reverse-proxies/proxy-tokens/tok-1", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.ReverseProxyTokens.Delete(context.Background(), "tok-1") + assert.Error(t, err) + }) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 353aff72d..6b8939598 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3067,6 +3067,17 @@ components: $ref: '#/components/schemas/AccessRestrictions' meta: $ref: '#/components/schemas/ServiceMeta' + private: + type: boolean + description: When true, the service is NetBird-only โ€” its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http. + default: false + example: false + access_groups: + type: array + items: + type: string + description: NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO). + example: ["group-engineering"] required: - id - name @@ -3147,6 +3158,17 @@ components: $ref: '#/components/schemas/ServiceAuthConfig' access_restrictions: $ref: '#/components/schemas/AccessRestrictions' + private: + type: boolean + description: When true, the service is NetBird-only โ€” its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http. + default: false + example: false + access_groups: + type: array + items: + type: string + description: NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO). + example: ["group-engineering"] required: - name - domain @@ -3185,6 +3207,15 @@ components: type: string description: Idle timeout before a UDP session is reaped, as a Go duration string (e.g. "30s", "2m"). example: "2m" + direct_upstream: + type: boolean + description: | + When true, the proxy dials this target via the host's network stack + instead of through its embedded NetBird client. Use for upstreams + reachable without WireGuard (public APIs, LAN services, localhost + sidecars). + default: false + example: false ServiceTarget: type: object properties: @@ -3195,7 +3226,7 @@ components: target_type: type: string description: Target type - enum: [peer, host, domain, subnet] + enum: [peer, host, domain, subnet, cluster] example: "subnet" path: type: string @@ -3439,6 +3470,10 @@ components: type: boolean description: Whether all active proxies in the cluster have CrowdSec configured example: false + private: + type: boolean + description: True when at least one connected proxy in this cluster is running embedded in a netbird client (`netbird proxy`) and serving over a WireGuard tunnel. Lets the dashboard distinguish per-peer / private clusters from centralised ones. + example: false required: - id - address @@ -3494,6 +3529,10 @@ components: type: boolean description: Whether the proxy cluster has CrowdSec configured example: false + supports_private: + type: boolean + description: Whether the proxy cluster supports private (NetBird-only) services. True when at least one connected proxy in the cluster runs embedded in a netbird client. + example: false required: - id - domain diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 16e765f8c..d7945e448 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1063,15 +1063,18 @@ func (e ServiceTargetProtocol) Valid() bool { // Defines values for ServiceTargetTargetType. const ( - ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" - ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" - ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" - ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" + ServiceTargetTargetTypeCluster ServiceTargetTargetType = "cluster" + ServiceTargetTargetTypeDomain ServiceTargetTargetType = "domain" + ServiceTargetTargetTypeHost ServiceTargetTargetType = "host" + ServiceTargetTargetTypePeer ServiceTargetTargetType = "peer" + ServiceTargetTargetTypeSubnet ServiceTargetTargetType = "subnet" ) // Valid indicates whether the value is a known member of the ServiceTargetTargetType enum. func (e ServiceTargetTargetType) Valid() bool { switch e { + case ServiceTargetTargetTypeCluster: + return true case ServiceTargetTargetTypeDomain: return true case ServiceTargetTargetTypeHost: @@ -3819,6 +3822,9 @@ type ProxyCluster struct { // Online Whether at least one proxy in the cluster has heartbeated within the active window Online bool `json:"online"` + // Private True when at least one connected proxy in this cluster is running embedded in a netbird client (`netbird proxy`) and serving over a WireGuard tunnel. Lets the dashboard distinguish per-peer / private clusters from centralised ones. + Private *bool `json:"private,omitempty"` + // RequireSubdomain Whether services on this cluster must include a subdomain label RequireSubdomain *bool `json:"require_subdomain,omitempty"` @@ -3896,6 +3902,9 @@ type ReverseProxyDomain struct { // SupportsCustomPorts Whether the cluster supports binding arbitrary TCP/UDP ports SupportsCustomPorts *bool `json:"supports_custom_ports,omitempty"` + // SupportsPrivate Whether the proxy cluster supports private (NetBird-only) services. True when at least one connected proxy in the cluster runs embedded in a netbird client. + SupportsPrivate *bool `json:"supports_private,omitempty"` + // TargetCluster The proxy cluster this domain is validated against (only for custom domains) TargetCluster *string `json:"target_cluster,omitempty"` @@ -4085,6 +4094,9 @@ type SentinelOneMatchAttributesNetworkStatus string // Service defines model for Service. type Service struct { + // AccessGroups NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO). + AccessGroups *[]string `json:"access_groups,omitempty"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` Auth ServiceAuthConfig `json:"auth"` @@ -4114,6 +4126,9 @@ type Service struct { // PortAutoAssigned Whether the listen port was auto-assigned PortAutoAssigned *bool `json:"port_auto_assigned,omitempty"` + // Private When true, the service is NetBird-only โ€” its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http. + Private *bool `json:"private,omitempty"` + // ProxyCluster The proxy cluster handling this service (derived from domain) ProxyCluster *string `json:"proxy_cluster,omitempty"` @@ -4156,6 +4171,9 @@ type ServiceMetaStatus string // ServiceRequest defines model for ServiceRequest. type ServiceRequest struct { + // AccessGroups NetBird group IDs whose peers may reach this private service over the tunnel. Required when private=true; ignored otherwise. Mutually exclusive with bearer auth (SSO). + AccessGroups *[]string `json:"access_groups,omitempty"` + // AccessRestrictions Connection-level access restrictions based on IP address or geography. Applies to both HTTP and L4 services. AccessRestrictions *AccessRestrictions `json:"access_restrictions,omitempty"` Auth *ServiceAuthConfig `json:"auth,omitempty"` @@ -4178,6 +4196,9 @@ type ServiceRequest struct { // PassHostHeader When true, the original client Host header is passed through to the backend instead of being rewritten to the backend's address PassHostHeader *bool `json:"pass_host_header,omitempty"` + // Private When true, the service is NetBird-only โ€” its target points at a proxy cluster, inbound peers authenticate via their WireGuard tunnel identity (no OIDC), and an ACL policy is auto-generated from access_groups to the cluster's proxy-peer group. Requires mode=http. + Private *bool `json:"private,omitempty"` + // RewriteRedirects When true, Location headers in backend responses are rewritten to replace the backend address with the public-facing domain RewriteRedirects *bool `json:"rewrite_redirects,omitempty"` @@ -4224,6 +4245,12 @@ type ServiceTargetOptions struct { // CustomHeaders Extra headers sent to the backend. Hop-by-hop and proxy-managed headers (Host, Connection, Transfer-Encoding, etc.) are rejected. CustomHeaders *map[string]string `json:"custom_headers,omitempty"` + // DirectUpstream When true, the proxy dials this target via the host's network stack + // instead of through its embedded NetBird client. Use for upstreams + // reachable without WireGuard (public APIs, LAN services, localhost + // sidecars). + DirectUpstream *bool `json:"direct_upstream,omitempty"` + // PathRewrite Controls how the request path is rewritten before forwarding to the backend. Default strips the matched prefix. "preserve" keeps the full original request path. PathRewrite *ServiceTargetOptionsPathRewrite `json:"path_rewrite,omitempty"` diff --git a/shared/management/proto/proxy_service.pb.go b/shared/management/proto/proxy_service.pb.go index a3a5e4588..22c215074 100644 --- a/shared/management/proto/proxy_service.pb.go +++ b/shared/management/proto/proxy_service.pb.go @@ -188,6 +188,13 @@ type ProxyCapabilities struct { RequireSubdomain *bool `protobuf:"varint,2,opt,name=require_subdomain,json=requireSubdomain,proto3,oneof" json:"require_subdomain,omitempty"` // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. SupportsCrowdsec *bool `protobuf:"varint,3,opt,name=supports_crowdsec,json=supportsCrowdsec,proto3,oneof" json:"supports_crowdsec,omitempty"` + // Whether the proxy is running embedded in the netbird client and serving + // exclusively over the WireGuard tunnel (i.e. `netbird proxy` rather than + // the standalone netbird-proxy binary). Surfaces upstream so dashboards can + // distinguish per-peer / private clusters from centralised ones. + Private *bool `protobuf:"varint,4,opt,name=private,proto3,oneof" json:"private,omitempty"` + // Whether the proxy enforces ProxyMapping.private (fails closed on ValidateTunnelPeer failure). Management MUST NOT stream private mappings to proxies that don't claim this. + SupportsPrivateService *bool `protobuf:"varint,5,opt,name=supports_private_service,json=supportsPrivateService,proto3,oneof" json:"supports_private_service,omitempty"` } func (x *ProxyCapabilities) Reset() { @@ -243,6 +250,20 @@ func (x *ProxyCapabilities) GetSupportsCrowdsec() bool { return false } +func (x *ProxyCapabilities) GetPrivate() bool { + if x != nil && x.Private != nil { + return *x.Private + } + return false +} + +func (x *ProxyCapabilities) GetSupportsPrivateService() bool { + if x != nil && x.SupportsPrivateService != nil { + return *x.SupportsPrivateService + } + return false +} + // GetMappingUpdateRequest is sent to initialise a mapping stream. type GetMappingUpdateRequest struct { state protoimpl.MessageState @@ -396,6 +417,11 @@ type PathTargetOptions struct { ProxyProtocol bool `protobuf:"varint,5,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` // Idle timeout before a UDP session is reaped. SessionIdleTimeout *durationpb.Duration `protobuf:"bytes,6,opt,name=session_idle_timeout,json=sessionIdleTimeout,proto3" json:"session_idle_timeout,omitempty"` + // When true, the proxy dials this target via the host's network stack + // instead of through the embedded NetBird client. Useful for upstreams + // reachable without WireGuard (public APIs, LAN services, localhost + // sidecars). Defaults to false โ€” embedded client is the standard path. + DirectUpstream bool `protobuf:"varint,7,opt,name=direct_upstream,json=directUpstream,proto3" json:"direct_upstream,omitempty"` } func (x *PathTargetOptions) Reset() { @@ -472,6 +498,13 @@ func (x *PathTargetOptions) GetSessionIdleTimeout() *durationpb.Duration { return nil } +func (x *PathTargetOptions) GetDirectUpstream() bool { + if x != nil { + return x.DirectUpstream + } + return false +} + type PathMapping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -782,6 +815,8 @@ type ProxyMapping struct { // For L4/TLS: the port the proxy listens on. ListenPort int32 `protobuf:"varint,11,opt,name=listen_port,json=listenPort,proto3" json:"listen_port,omitempty"` AccessRestrictions *AccessRestrictions `protobuf:"bytes,12,opt,name=access_restrictions,json=accessRestrictions,proto3" json:"access_restrictions,omitempty"` + // NetBird-only: the proxy MUST call ValidateTunnelPeer and fail closed; operator auth schemes are bypassed. + Private bool `protobuf:"varint,13,opt,name=private,proto3" json:"private,omitempty"` } func (x *ProxyMapping) Reset() { @@ -900,6 +935,13 @@ func (x *ProxyMapping) GetAccessRestrictions() *AccessRestrictions { return nil } +func (x *ProxyMapping) GetPrivate() bool { + if x != nil { + return x.Private + } + return false +} + // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. type SendAccessLogRequest struct { state protoimpl.MessageState @@ -1489,6 +1531,11 @@ type SendStatusUpdateRequest struct { Status ProxyStatus `protobuf:"varint,3,opt,name=status,proto3,enum=management.ProxyStatus" json:"status,omitempty"` CertificateIssued bool `protobuf:"varint,4,opt,name=certificate_issued,json=certificateIssued,proto3" json:"certificate_issued,omitempty"` ErrorMessage *string `protobuf:"bytes,5,opt,name=error_message,json=errorMessage,proto3,oneof" json:"error_message,omitempty"` + // Per-account inbound listener state for the account that owns + // service_id. Populated only when --private-inbound is enabled and the + // embedded client for the account is up. Field numbers >=50 reserved + // for observability extensions. + InboundListener *ProxyInboundListener `protobuf:"bytes,50,opt,name=inbound_listener,json=inboundListener,proto3,oneof" json:"inbound_listener,omitempty"` } func (x *SendStatusUpdateRequest) Reset() { @@ -1558,6 +1605,84 @@ func (x *SendStatusUpdateRequest) GetErrorMessage() string { return "" } +func (x *SendStatusUpdateRequest) GetInboundListener() *ProxyInboundListener { + if x != nil { + return x.InboundListener + } + return nil +} + +// ProxyInboundListener describes a per-account inbound listener that the +// proxy has bound on the embedded netstack of the account's WireGuard +// client. Surfaced so dashboards can render "this account is reachable +// at : on this proxy". +type ProxyInboundListener struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Tunnel IP the embedded netstack listens on. Same address other peers + // in the account see for the proxy peer. + TunnelIp string `protobuf:"bytes,1,opt,name=tunnel_ip,json=tunnelIp,proto3" json:"tunnel_ip,omitempty"` + // TLS port served on tunnel_ip (auto-detected, default 443). + HttpsPort uint32 `protobuf:"varint,2,opt,name=https_port,json=httpsPort,proto3" json:"https_port,omitempty"` + // Plain-HTTP port served on tunnel_ip (auto-detected, default 80). + HttpPort uint32 `protobuf:"varint,3,opt,name=http_port,json=httpPort,proto3" json:"http_port,omitempty"` +} + +func (x *ProxyInboundListener) Reset() { + *x = ProxyInboundListener{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ProxyInboundListener) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProxyInboundListener) ProtoMessage() {} + +func (x *ProxyInboundListener) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[18] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProxyInboundListener.ProtoReflect.Descriptor instead. +func (*ProxyInboundListener) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{18} +} + +func (x *ProxyInboundListener) GetTunnelIp() string { + if x != nil { + return x.TunnelIp + } + return "" +} + +func (x *ProxyInboundListener) GetHttpsPort() uint32 { + if x != nil { + return x.HttpsPort + } + return 0 +} + +func (x *ProxyInboundListener) GetHttpPort() uint32 { + if x != nil { + return x.HttpPort + } + return 0 +} + // SendStatusUpdateResponse is intentionally empty to allow for future expansion type SendStatusUpdateResponse struct { state protoimpl.MessageState @@ -1568,7 +1693,7 @@ type SendStatusUpdateResponse struct { func (x *SendStatusUpdateResponse) Reset() { *x = SendStatusUpdateResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1581,7 +1706,7 @@ func (x *SendStatusUpdateResponse) String() string { func (*SendStatusUpdateResponse) ProtoMessage() {} func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[18] + mi := &file_proxy_service_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1594,7 +1719,7 @@ func (x *SendStatusUpdateResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SendStatusUpdateResponse.ProtoReflect.Descriptor instead. func (*SendStatusUpdateResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{18} + return file_proxy_service_proto_rawDescGZIP(), []int{19} } // CreateProxyPeerRequest is sent by the proxy to create a peer connection @@ -1614,7 +1739,7 @@ type CreateProxyPeerRequest struct { func (x *CreateProxyPeerRequest) Reset() { *x = CreateProxyPeerRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1627,7 +1752,7 @@ func (x *CreateProxyPeerRequest) String() string { func (*CreateProxyPeerRequest) ProtoMessage() {} func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[19] + mi := &file_proxy_service_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1640,7 +1765,7 @@ func (x *CreateProxyPeerRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerRequest.ProtoReflect.Descriptor instead. func (*CreateProxyPeerRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{19} + return file_proxy_service_proto_rawDescGZIP(), []int{20} } func (x *CreateProxyPeerRequest) GetServiceId() string { @@ -1691,7 +1816,7 @@ type CreateProxyPeerResponse struct { func (x *CreateProxyPeerResponse) Reset() { *x = CreateProxyPeerResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1704,7 +1829,7 @@ func (x *CreateProxyPeerResponse) String() string { func (*CreateProxyPeerResponse) ProtoMessage() {} func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[20] + mi := &file_proxy_service_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1717,7 +1842,7 @@ func (x *CreateProxyPeerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateProxyPeerResponse.ProtoReflect.Descriptor instead. func (*CreateProxyPeerResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{20} + return file_proxy_service_proto_rawDescGZIP(), []int{21} } func (x *CreateProxyPeerResponse) GetSuccess() bool { @@ -1747,7 +1872,7 @@ type GetOIDCURLRequest struct { func (x *GetOIDCURLRequest) Reset() { *x = GetOIDCURLRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[21] + mi := &file_proxy_service_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1760,7 +1885,7 @@ func (x *GetOIDCURLRequest) String() string { func (*GetOIDCURLRequest) ProtoMessage() {} func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[21] + mi := &file_proxy_service_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1773,7 +1898,7 @@ func (x *GetOIDCURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLRequest.ProtoReflect.Descriptor instead. func (*GetOIDCURLRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{21} + return file_proxy_service_proto_rawDescGZIP(), []int{22} } func (x *GetOIDCURLRequest) GetId() string { @@ -1808,7 +1933,7 @@ type GetOIDCURLResponse struct { func (x *GetOIDCURLResponse) Reset() { *x = GetOIDCURLResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[22] + mi := &file_proxy_service_proto_msgTypes[23] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1821,7 +1946,7 @@ func (x *GetOIDCURLResponse) String() string { func (*GetOIDCURLResponse) ProtoMessage() {} func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[22] + mi := &file_proxy_service_proto_msgTypes[23] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1834,7 +1959,7 @@ func (x *GetOIDCURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetOIDCURLResponse.ProtoReflect.Descriptor instead. func (*GetOIDCURLResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{22} + return file_proxy_service_proto_rawDescGZIP(), []int{23} } func (x *GetOIDCURLResponse) GetUrl() string { @@ -1856,7 +1981,7 @@ type ValidateSessionRequest struct { func (x *ValidateSessionRequest) Reset() { *x = ValidateSessionRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[23] + mi := &file_proxy_service_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1869,7 +1994,7 @@ func (x *ValidateSessionRequest) String() string { func (*ValidateSessionRequest) ProtoMessage() {} func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[23] + mi := &file_proxy_service_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1882,7 +2007,7 @@ func (x *ValidateSessionRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionRequest.ProtoReflect.Descriptor instead. func (*ValidateSessionRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{23} + return file_proxy_service_proto_rawDescGZIP(), []int{24} } func (x *ValidateSessionRequest) GetDomain() string { @@ -1908,12 +2033,21 @@ type ValidateSessionResponse struct { UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` + // peer_group_ids carries the calling user's group memberships so the + // proxy can authorise policy-aware middlewares without an additional + // management round-trip. + PeerGroupIds []string `protobuf:"bytes,5,rep,name=peer_group_ids,json=peerGroupIds,proto3" json:"peer_group_ids,omitempty"` + // peer_group_names carries the human-readable display names for the + // ids in peer_group_ids, ordered identically (positional pairing). + // Stamped onto upstream requests as X-NetBird-Groups so downstream + // services can read names rather than opaque ids. + PeerGroupNames []string `protobuf:"bytes,6,rep,name=peer_group_names,json=peerGroupNames,proto3" json:"peer_group_names,omitempty"` } func (x *ValidateSessionResponse) Reset() { *x = ValidateSessionResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[24] + mi := &file_proxy_service_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1926,7 +2060,7 @@ func (x *ValidateSessionResponse) String() string { func (*ValidateSessionResponse) ProtoMessage() {} func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[24] + mi := &file_proxy_service_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1939,7 +2073,7 @@ func (x *ValidateSessionResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ValidateSessionResponse.ProtoReflect.Descriptor instead. func (*ValidateSessionResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{24} + return file_proxy_service_proto_rawDescGZIP(), []int{25} } func (x *ValidateSessionResponse) GetValid() bool { @@ -1970,6 +2104,193 @@ func (x *ValidateSessionResponse) GetDeniedReason() string { return "" } +func (x *ValidateSessionResponse) GetPeerGroupIds() []string { + if x != nil { + return x.PeerGroupIds + } + return nil +} + +func (x *ValidateSessionResponse) GetPeerGroupNames() []string { + if x != nil { + return x.PeerGroupNames + } + return nil +} + +// ValidateTunnelPeerRequest carries the inbound peer's tunnel IP and the +// service domain whose group requirements should gate access. The calling +// account is inferred from the proxy's gRPC metadata (ProxyToken). +type ValidateTunnelPeerRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TunnelIp string `protobuf:"bytes,1,opt,name=tunnel_ip,json=tunnelIp,proto3" json:"tunnel_ip,omitempty"` + Domain string `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"` +} + +func (x *ValidateTunnelPeerRequest) Reset() { + *x = ValidateTunnelPeerRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ValidateTunnelPeerRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ValidateTunnelPeerRequest) ProtoMessage() {} + +func (x *ValidateTunnelPeerRequest) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[26] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ValidateTunnelPeerRequest.ProtoReflect.Descriptor instead. +func (*ValidateTunnelPeerRequest) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{26} +} + +func (x *ValidateTunnelPeerRequest) GetTunnelIp() string { + if x != nil { + return x.TunnelIp + } + return "" +} + +func (x *ValidateTunnelPeerRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +// ValidateTunnelPeerResponse mirrors ValidateSessionResponse plus a freshly +// minted session_token: when valid is true, the proxy installs the token as +// a session cookie so subsequent requests skip the management round-trip, +// matching the OIDC flow's UX. denied_reason values: +// +// "peer_not_found" โ€” no peer with that tunnel IP in the calling account +// "no_user" โ€” peer exists but is not bound to a user +// "service_not_found" +// "account_mismatch" +// "not_in_group" โ€” peer resolved but not in service.access_groups +type ValidateTunnelPeerResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Valid bool `protobuf:"varint,1,opt,name=valid,proto3" json:"valid,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + UserEmail string `protobuf:"bytes,3,opt,name=user_email,json=userEmail,proto3" json:"user_email,omitempty"` + DeniedReason string `protobuf:"bytes,4,opt,name=denied_reason,json=deniedReason,proto3" json:"denied_reason,omitempty"` + // session_token is set only when valid is true. Same shape as the JWT + // the OIDC flow produces โ€” proxy installs it via setSessionCookie so the + // tunnel fast-path is indistinguishable from OIDC for subsequent requests. + SessionToken string `protobuf:"bytes,5,opt,name=session_token,json=sessionToken,proto3" json:"session_token,omitempty"` + // peer_group_ids carries the resolved peer's user group memberships so + // the proxy can authorise policy-aware middlewares without an additional + // management round-trip. + PeerGroupIds []string `protobuf:"bytes,6,rep,name=peer_group_ids,json=peerGroupIds,proto3" json:"peer_group_ids,omitempty"` + // peer_group_names carries the human-readable display names for the + // ids in peer_group_ids, ordered identically (positional pairing). + // Stamped onto upstream requests as X-NetBird-Groups so downstream + // services can read names rather than opaque ids. + PeerGroupNames []string `protobuf:"bytes,7,rep,name=peer_group_names,json=peerGroupNames,proto3" json:"peer_group_names,omitempty"` +} + +func (x *ValidateTunnelPeerResponse) Reset() { + *x = ValidateTunnelPeerResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_proxy_service_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ValidateTunnelPeerResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ValidateTunnelPeerResponse) ProtoMessage() {} + +func (x *ValidateTunnelPeerResponse) ProtoReflect() protoreflect.Message { + mi := &file_proxy_service_proto_msgTypes[27] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ValidateTunnelPeerResponse.ProtoReflect.Descriptor instead. +func (*ValidateTunnelPeerResponse) Descriptor() ([]byte, []int) { + return file_proxy_service_proto_rawDescGZIP(), []int{27} +} + +func (x *ValidateTunnelPeerResponse) GetValid() bool { + if x != nil { + return x.Valid + } + return false +} + +func (x *ValidateTunnelPeerResponse) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *ValidateTunnelPeerResponse) GetUserEmail() string { + if x != nil { + return x.UserEmail + } + return "" +} + +func (x *ValidateTunnelPeerResponse) GetDeniedReason() string { + if x != nil { + return x.DeniedReason + } + return "" +} + +func (x *ValidateTunnelPeerResponse) GetSessionToken() string { + if x != nil { + return x.SessionToken + } + return "" +} + +func (x *ValidateTunnelPeerResponse) GetPeerGroupIds() []string { + if x != nil { + return x.PeerGroupIds + } + return nil +} + +func (x *ValidateTunnelPeerResponse) GetPeerGroupNames() []string { + if x != nil { + return x.PeerGroupNames + } + return nil +} + // SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings // stream. The first message MUST be an init; all subsequent messages MUST be // acks. @@ -1988,7 +2309,7 @@ type SyncMappingsRequest struct { func (x *SyncMappingsRequest) Reset() { *x = SyncMappingsRequest{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[25] + mi := &file_proxy_service_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2001,7 +2322,7 @@ func (x *SyncMappingsRequest) String() string { func (*SyncMappingsRequest) ProtoMessage() {} func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[25] + mi := &file_proxy_service_proto_msgTypes[28] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2014,7 +2335,7 @@ func (x *SyncMappingsRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsRequest.ProtoReflect.Descriptor instead. func (*SyncMappingsRequest) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{25} + return file_proxy_service_proto_rawDescGZIP(), []int{28} } func (m *SyncMappingsRequest) GetMsg() isSyncMappingsRequest_Msg { @@ -2071,7 +2392,7 @@ type SyncMappingsInit struct { func (x *SyncMappingsInit) Reset() { *x = SyncMappingsInit{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[26] + mi := &file_proxy_service_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2084,7 +2405,7 @@ func (x *SyncMappingsInit) String() string { func (*SyncMappingsInit) ProtoMessage() {} func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[26] + mi := &file_proxy_service_proto_msgTypes[29] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2097,7 +2418,7 @@ func (x *SyncMappingsInit) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsInit.ProtoReflect.Descriptor instead. func (*SyncMappingsInit) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{26} + return file_proxy_service_proto_rawDescGZIP(), []int{29} } func (x *SyncMappingsInit) GetProxyId() string { @@ -2146,7 +2467,7 @@ type SyncMappingsAck struct { func (x *SyncMappingsAck) Reset() { *x = SyncMappingsAck{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[27] + mi := &file_proxy_service_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2159,7 +2480,7 @@ func (x *SyncMappingsAck) String() string { func (*SyncMappingsAck) ProtoMessage() {} func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[27] + mi := &file_proxy_service_proto_msgTypes[30] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2172,7 +2493,7 @@ func (x *SyncMappingsAck) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsAck.ProtoReflect.Descriptor instead. func (*SyncMappingsAck) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{27} + return file_proxy_service_proto_rawDescGZIP(), []int{30} } // SyncMappingsResponse is a batch of mappings sent by management. @@ -2190,7 +2511,7 @@ type SyncMappingsResponse struct { func (x *SyncMappingsResponse) Reset() { *x = SyncMappingsResponse{} if protoimpl.UnsafeEnabled { - mi := &file_proxy_service_proto_msgTypes[28] + mi := &file_proxy_service_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2203,7 +2524,7 @@ func (x *SyncMappingsResponse) String() string { func (*SyncMappingsResponse) ProtoMessage() {} func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { - mi := &file_proxy_service_proto_msgTypes[28] + mi := &file_proxy_service_proto_msgTypes[31] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2216,7 +2537,7 @@ func (x *SyncMappingsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SyncMappingsResponse.ProtoReflect.Descriptor instead. func (*SyncMappingsResponse) Descriptor() ([]byte, []int) { - return file_proxy_service_proto_rawDescGZIP(), []int{28} + return file_proxy_service_proto_rawDescGZIP(), []int{31} } func (x *SyncMappingsResponse) GetMapping() []*ProxyMapping { @@ -2242,7 +2563,7 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xf6, 0x01, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, + 0x74, 0x6f, 0x22, 0xfd, 0x02, 0x0a, 0x11, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x15, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x13, 0x73, 0x75, 0x70, 0x70, 0x6f, @@ -2253,59 +2574,70 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x11, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x5f, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x48, 0x02, 0x52, 0x10, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x43, 0x72, 0x6f, 0x77, 0x64, 0x73, - 0x65, 0x63, 0x88, 0x01, 0x01, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, - 0x74, 0x73, 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x42, - 0x14, 0x0a, 0x12, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, - 0x74, 0x73, 0x5f, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x22, 0xe6, 0x01, 0x0a, 0x17, + 0x65, 0x63, 0x88, 0x01, 0x01, 0x12, 0x1d, 0x0a, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x48, 0x03, 0x52, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, + 0x65, 0x88, 0x01, 0x01, 0x12, 0x3d, 0x0a, 0x18, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x5f, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x48, 0x04, 0x52, 0x16, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x50, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x88, 0x01, 0x01, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x5f, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x42, 0x14, 0x0a, + 0x12, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x5f, 0x73, 0x75, 0x62, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x42, 0x14, 0x0a, 0x12, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, + 0x5f, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x42, 0x0a, 0x0a, 0x08, 0x5f, 0x70, 0x72, + 0x69, 0x76, 0x61, 0x74, 0x65, 0x42, 0x1b, 0x0a, 0x19, 0x5f, 0x73, 0x75, 0x70, 0x70, 0x6f, 0x72, + 0x74, 0x73, 0x5f, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x22, 0xe6, 0x01, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, + 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, + 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, + 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, + 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, + 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, - 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, - 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, - 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, - 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, - 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, - 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, - 0x74, 0x69, 0x65, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, - 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, - 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, - 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0xce, 0x03, 0x0a, 0x11, 0x50, 0x61, - 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, - 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, - 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, - 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, + 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, + 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, + 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, + 0x22, 0xf7, 0x03, 0x0a, 0x11, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x74, + 0x6c, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x54, 0x6c, 0x73, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x12, 0x42, + 0x0a, 0x0f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, + 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, + 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, + 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, + 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, + 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, + 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, + 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3e, 0x0a, 0x0c, 0x70, - 0x61, 0x74, 0x68, 0x5f, 0x72, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, - 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x0b, - 0x70, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x12, 0x57, 0x0a, 0x0e, 0x63, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x61, 0x74, 0x68, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x73, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, - 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x4b, 0x0a, 0x14, 0x73, - 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x6c, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, - 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, - 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, + 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x12, 0x73, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x6c, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, + 0x27, 0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x70, 0x73, 0x74, 0x72, 0x65, + 0x61, 0x6d, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x70, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x1a, 0x40, 0x0a, 0x12, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, @@ -2350,7 +2682,7 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x72, 0x6f, 0x77, 0x64, 0x73, 0x65, 0x63, 0x4d, 0x6f, 0x64, 0x65, 0x22, - 0xe6, 0x03, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x80, 0x04, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x36, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, @@ -2380,244 +2712,292 @@ var file_proxy_service_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x12, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x74, - 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, - 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, - 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x84, 0x05, 0x0a, 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, - 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, - 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, - 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, - 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, - 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, - 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, - 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, - 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, - 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, - 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, - 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, - 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, - 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, - 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, - 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, - 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, - 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, - 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, - 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, - 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, - 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, - 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, - 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x6f, 0x67, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, - 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, - 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, - 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, - 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, + 0x72, 0x69, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x72, 0x69, 0x76, + 0x61, 0x74, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, + 0x74, 0x65, 0x22, 0x3f, 0x0a, 0x14, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x27, 0x0a, 0x03, 0x6c, 0x6f, + 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x03, + 0x6c, 0x6f, 0x67, 0x22, 0x17, 0x0a, 0x15, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x84, 0x05, 0x0a, + 0x09, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x12, 0x15, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x5f, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6c, 0x6f, 0x67, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6f, 0x73, + 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x12, 0x1f, 0x0a, 0x0b, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x6d, 0x73, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x4d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, 0x25, 0x0a, 0x0e, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6d, 0x65, 0x63, 0x68, 0x61, 0x6e, 0x69, 0x73, 0x6d, 0x18, 0x0b, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x61, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x63, 0x68, 0x61, 0x6e, + 0x69, 0x73, 0x6d, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, + 0x61, 0x75, 0x74, 0x68, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x53, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, + 0x21, 0x0a, 0x0c, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x75, 0x70, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x62, 0x79, 0x74, 0x65, 0x73, 0x55, 0x70, 0x6c, 0x6f, + 0x61, 0x64, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x64, 0x6f, 0x77, 0x6e, + 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, + 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x18, 0x11, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x2e, 0x4d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, + 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, + 0x02, 0x38, 0x01, 0x22, 0xf8, 0x01, 0x0a, 0x13, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, + 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, + 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, + 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, + 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, + 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, + 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, + 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, + 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x2d, 0x0a, 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, + 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, + 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, + 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xda, 0x02, + 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, + 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, + 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, + 0x01, 0x12, 0x50, 0x0a, 0x10, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x5f, 0x6c, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x65, 0x72, 0x18, 0x32, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x6e, + 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x48, 0x01, 0x52, + 0x0f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, + 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x69, 0x6e, 0x62, 0x6f, 0x75, 0x6e, + 0x64, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x14, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x49, 0x6e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, + 0x65, 0x72, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x5f, 0x69, 0x70, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x49, 0x70, 0x12, + 0x1d, 0x0a, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x73, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x09, 0x68, 0x74, 0x74, 0x70, 0x73, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1b, + 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x53, + 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, - 0x12, 0x39, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, - 0x00, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x12, 0x2a, 0x0a, 0x03, 0x70, - 0x69, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x48, 0x00, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x12, 0x40, 0x0a, 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, - 0x72, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x41, 0x75, 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x68, - 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, 0x74, 0x68, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x22, 0x57, 0x0a, 0x11, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x41, 0x75, - 0x74, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0b, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, - 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x0a, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x2d, 0x0a, - 0x0f, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x1e, 0x0a, 0x0a, - 0x50, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x69, - 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x69, 0x6e, 0x22, 0x55, 0x0a, 0x14, - 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x23, - 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, - 0x6b, 0x65, 0x6e, 0x22, 0xf3, 0x01, 0x0a, 0x17, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, 0x1d, - 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2f, 0x0a, - 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x2d, - 0x0a, 0x12, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x69, 0x73, - 0x73, 0x75, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x63, 0x65, 0x72, 0x74, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x49, 0x73, 0x73, 0x75, 0x65, 0x64, 0x12, 0x28, 0x0a, - 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x1a, 0x0a, 0x18, 0x53, 0x65, 0x6e, - 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0xb8, 0x01, 0x0a, 0x16, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x49, 0x64, 0x12, - 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x14, - 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, - 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, - 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x75, 0x62, - 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, - 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, - 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, - 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, - 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, - 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x0c, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, 0x01, 0x42, - 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, - 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65, 0x64, - 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4f, - 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, - 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, - 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, - 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, - 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, - 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x8c, 0x01, 0x0a, 0x17, 0x56, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, - 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, - 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, - 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, - 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, - 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x4d, - 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x32, - 0x0a, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, 0x69, 0x6e, - 0x69, 0x74, 0x12, 0x2f, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, - 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, 0x00, 0x52, 0x03, - 0x61, 0x63, 0x6b, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, 0x0a, 0x10, 0x53, - 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x12, - 0x19, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, - 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, - 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, - 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x43, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, - 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x11, 0x0a, 0x0f, - 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x22, - 0x7e, 0x0a, 0x14, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, - 0x6e, 0x67, 0x52, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, - 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, - 0x69, 0x61, 0x6c, 0x53, 0x79, 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x2a, - 0x64, 0x0a, 0x16, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, - 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, - 0x10, 0x00, 0x12, 0x18, 0x0a, 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, - 0x45, 0x5f, 0x4d, 0x4f, 0x44, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, - 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, - 0x56, 0x45, 0x44, 0x10, 0x02, 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, - 0x72, 0x69, 0x74, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, - 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, - 0x10, 0x00, 0x12, 0x19, 0x0a, 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, - 0x54, 0x45, 0x5f, 0x50, 0x52, 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, - 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, - 0x14, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, - 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, - 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, - 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x54, 0x55, 0x4e, 0x4e, 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, - 0x54, 0x45, 0x44, 0x10, 0x02, 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, - 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, - 0x45, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, - 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, - 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, - 0x12, 0x16, 0x0a, 0x12, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, - 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x05, 0x32, 0xd3, 0x05, 0x0a, 0x0c, 0x50, 0x72, 0x6f, - 0x78, 0x79, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, - 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, - 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x53, 0x79, - 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, + 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x30, 0x0a, 0x14, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6c, 0x75, 0x73, + 0x74, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6c, 0x75, 0x73, 0x74, + 0x65, 0x72, 0x22, 0x6f, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x28, 0x0a, 0x0d, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, + 0x52, 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x88, 0x01, + 0x01, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x22, 0x65, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, + 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x61, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, + 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x72, 0x6c, 0x22, 0x26, 0x0a, 0x12, 0x47, 0x65, + 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, + 0x72, 0x6c, 0x22, 0x55, 0x0a, 0x16, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, + 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, + 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, + 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xdc, 0x01, 0x0a, 0x17, 0x56, 0x61, + 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x75, + 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, + 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x45, 0x6d, + 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x5f, 0x72, 0x65, + 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x6e, 0x69, + 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, 0x72, + 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, + 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, + 0x65, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, + 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x22, 0x50, 0x0a, 0x19, 0x56, 0x61, 0x6c, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x5f, + 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, + 0x49, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x22, 0x84, 0x02, 0x0a, 0x1a, 0x56, + 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, + 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, + 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, + 0x5f, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x73, + 0x65, 0x72, 0x45, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x6e, 0x69, 0x65, + 0x64, 0x5f, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, + 0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, 0x52, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x23, 0x0a, 0x0d, + 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x12, 0x24, 0x0a, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, + 0x69, 0x64, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x65, 0x65, 0x72, 0x47, + 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x70, 0x65, 0x65, 0x72, 0x5f, + 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0e, 0x70, 0x65, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4e, 0x61, 0x6d, 0x65, + 0x73, 0x22, 0x81, 0x01, 0x0a, 0x13, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, + 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x04, 0x69, 0x6e, 0x69, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x73, 0x49, 0x6e, 0x69, 0x74, 0x48, 0x00, 0x52, 0x04, 0x69, 0x6e, 0x69, 0x74, 0x12, 0x2f, 0x0a, + 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, - 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, - 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, - 0x01, 0x12, 0x54, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, - 0x6f, 0x67, 0x12, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, - 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, - 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, + 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x48, 0x00, 0x52, 0x03, 0x61, 0x63, 0x6b, 0x42, 0x05, + 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xdf, 0x01, 0x0a, 0x10, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x72, + 0x6f, 0x78, 0x79, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x39, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, + 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x41, 0x0a, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x69, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x61, 0x70, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x52, 0x0c, 0x63, 0x61, 0x70, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x22, 0x11, 0x0a, 0x0f, 0x53, 0x79, 0x6e, 0x63, 0x4d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x41, 0x63, 0x6b, 0x22, 0x7e, 0x0a, 0x14, 0x53, 0x79, + 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x32, 0x0a, 0x07, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x52, 0x07, 0x6d, + 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, + 0x6c, 0x5f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x53, 0x79, + 0x6e, 0x63, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x2a, 0x64, 0x0a, 0x16, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, + 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x00, 0x12, 0x18, 0x0a, + 0x14, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x4f, 0x44, + 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x55, 0x50, 0x44, 0x41, 0x54, + 0x45, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x52, 0x45, 0x4d, 0x4f, 0x56, 0x45, 0x44, 0x10, 0x02, + 0x2a, 0x46, 0x0a, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x52, 0x65, 0x77, 0x72, 0x69, 0x74, 0x65, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, + 0x49, 0x54, 0x45, 0x5f, 0x44, 0x45, 0x46, 0x41, 0x55, 0x4c, 0x54, 0x10, 0x00, 0x12, 0x19, 0x0a, + 0x15, 0x50, 0x41, 0x54, 0x48, 0x5f, 0x52, 0x45, 0x57, 0x52, 0x49, 0x54, 0x45, 0x5f, 0x50, 0x52, + 0x45, 0x53, 0x45, 0x52, 0x56, 0x45, 0x10, 0x01, 0x2a, 0xc8, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x52, 0x4f, 0x58, + 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x50, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, + 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, + 0x55, 0x53, 0x5f, 0x41, 0x43, 0x54, 0x49, 0x56, 0x45, 0x10, 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x50, + 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x54, 0x55, 0x4e, 0x4e, + 0x45, 0x4c, 0x5f, 0x4e, 0x4f, 0x54, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, 0x02, + 0x12, 0x24, 0x0a, 0x20, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, 0x54, 0x45, 0x5f, 0x50, 0x45, 0x4e, + 0x44, 0x49, 0x4e, 0x47, 0x10, 0x03, 0x12, 0x23, 0x0a, 0x1f, 0x50, 0x52, 0x4f, 0x58, 0x59, 0x5f, + 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x43, 0x45, 0x52, 0x54, 0x49, 0x46, 0x49, 0x43, 0x41, + 0x54, 0x45, 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x12, 0x16, 0x0a, 0x12, 0x50, + 0x52, 0x4f, 0x58, 0x59, 0x5f, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x05, 0x32, 0xb8, 0x06, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x5f, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, + 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x61, + 0x70, 0x70, 0x69, 0x6e, 0x67, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x55, 0x0a, 0x0c, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, + 0x70, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x30, 0x01, 0x12, 0x54, 0x0a, 0x0d, + 0x53, 0x65, 0x6e, 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x12, 0x20, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x41, + 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, + 0x64, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x51, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, + 0x74, 0x65, 0x12, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5d, 0x0a, 0x10, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x6e, 0x64, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x24, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x53, 0x65, 0x6e, 0x64, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, - 0x55, 0x52, 0x4c, 0x12, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, - 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, - 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, - 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, 0x0f, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x4b, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x4f, 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x12, 0x1d, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, + 0x49, 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x47, 0x65, 0x74, 0x4f, 0x49, + 0x44, 0x43, 0x55, 0x52, 0x4c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x5a, 0x0a, + 0x0f, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, + 0x12, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, + 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x12, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x12, + 0x25, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x75, 0x6e, 0x6e, + 0x65, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } @@ -2634,53 +3014,56 @@ func file_proxy_service_proto_rawDescGZIP() []byte { } var file_proxy_service_proto_enumTypes = make([]protoimpl.EnumInfo, 3) -var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 31) +var file_proxy_service_proto_msgTypes = make([]protoimpl.MessageInfo, 34) var file_proxy_service_proto_goTypes = []interface{}{ - (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType - (PathRewriteMode)(0), // 1: management.PathRewriteMode - (ProxyStatus)(0), // 2: management.ProxyStatus - (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities - (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest - (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse - (*PathTargetOptions)(nil), // 6: management.PathTargetOptions - (*PathMapping)(nil), // 7: management.PathMapping - (*HeaderAuth)(nil), // 8: management.HeaderAuth - (*Authentication)(nil), // 9: management.Authentication - (*AccessRestrictions)(nil), // 10: management.AccessRestrictions - (*ProxyMapping)(nil), // 11: management.ProxyMapping - (*SendAccessLogRequest)(nil), // 12: management.SendAccessLogRequest - (*SendAccessLogResponse)(nil), // 13: management.SendAccessLogResponse - (*AccessLog)(nil), // 14: management.AccessLog - (*AuthenticateRequest)(nil), // 15: management.AuthenticateRequest - (*HeaderAuthRequest)(nil), // 16: management.HeaderAuthRequest - (*PasswordRequest)(nil), // 17: management.PasswordRequest - (*PinRequest)(nil), // 18: management.PinRequest - (*AuthenticateResponse)(nil), // 19: management.AuthenticateResponse - (*SendStatusUpdateRequest)(nil), // 20: management.SendStatusUpdateRequest - (*SendStatusUpdateResponse)(nil), // 21: management.SendStatusUpdateResponse - (*CreateProxyPeerRequest)(nil), // 22: management.CreateProxyPeerRequest - (*CreateProxyPeerResponse)(nil), // 23: management.CreateProxyPeerResponse - (*GetOIDCURLRequest)(nil), // 24: management.GetOIDCURLRequest - (*GetOIDCURLResponse)(nil), // 25: management.GetOIDCURLResponse - (*ValidateSessionRequest)(nil), // 26: management.ValidateSessionRequest - (*ValidateSessionResponse)(nil), // 27: management.ValidateSessionResponse - (*SyncMappingsRequest)(nil), // 28: management.SyncMappingsRequest - (*SyncMappingsInit)(nil), // 29: management.SyncMappingsInit - (*SyncMappingsAck)(nil), // 30: management.SyncMappingsAck - (*SyncMappingsResponse)(nil), // 31: management.SyncMappingsResponse - nil, // 32: management.PathTargetOptions.CustomHeadersEntry - nil, // 33: management.AccessLog.MetadataEntry - (*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp - (*durationpb.Duration)(nil), // 35: google.protobuf.Duration + (ProxyMappingUpdateType)(0), // 0: management.ProxyMappingUpdateType + (PathRewriteMode)(0), // 1: management.PathRewriteMode + (ProxyStatus)(0), // 2: management.ProxyStatus + (*ProxyCapabilities)(nil), // 3: management.ProxyCapabilities + (*GetMappingUpdateRequest)(nil), // 4: management.GetMappingUpdateRequest + (*GetMappingUpdateResponse)(nil), // 5: management.GetMappingUpdateResponse + (*PathTargetOptions)(nil), // 6: management.PathTargetOptions + (*PathMapping)(nil), // 7: management.PathMapping + (*HeaderAuth)(nil), // 8: management.HeaderAuth + (*Authentication)(nil), // 9: management.Authentication + (*AccessRestrictions)(nil), // 10: management.AccessRestrictions + (*ProxyMapping)(nil), // 11: management.ProxyMapping + (*SendAccessLogRequest)(nil), // 12: management.SendAccessLogRequest + (*SendAccessLogResponse)(nil), // 13: management.SendAccessLogResponse + (*AccessLog)(nil), // 14: management.AccessLog + (*AuthenticateRequest)(nil), // 15: management.AuthenticateRequest + (*HeaderAuthRequest)(nil), // 16: management.HeaderAuthRequest + (*PasswordRequest)(nil), // 17: management.PasswordRequest + (*PinRequest)(nil), // 18: management.PinRequest + (*AuthenticateResponse)(nil), // 19: management.AuthenticateResponse + (*SendStatusUpdateRequest)(nil), // 20: management.SendStatusUpdateRequest + (*ProxyInboundListener)(nil), // 21: management.ProxyInboundListener + (*SendStatusUpdateResponse)(nil), // 22: management.SendStatusUpdateResponse + (*CreateProxyPeerRequest)(nil), // 23: management.CreateProxyPeerRequest + (*CreateProxyPeerResponse)(nil), // 24: management.CreateProxyPeerResponse + (*GetOIDCURLRequest)(nil), // 25: management.GetOIDCURLRequest + (*GetOIDCURLResponse)(nil), // 26: management.GetOIDCURLResponse + (*ValidateSessionRequest)(nil), // 27: management.ValidateSessionRequest + (*ValidateSessionResponse)(nil), // 28: management.ValidateSessionResponse + (*ValidateTunnelPeerRequest)(nil), // 29: management.ValidateTunnelPeerRequest + (*ValidateTunnelPeerResponse)(nil), // 30: management.ValidateTunnelPeerResponse + (*SyncMappingsRequest)(nil), // 31: management.SyncMappingsRequest + (*SyncMappingsInit)(nil), // 32: management.SyncMappingsInit + (*SyncMappingsAck)(nil), // 33: management.SyncMappingsAck + (*SyncMappingsResponse)(nil), // 34: management.SyncMappingsResponse + nil, // 35: management.PathTargetOptions.CustomHeadersEntry + nil, // 36: management.AccessLog.MetadataEntry + (*timestamppb.Timestamp)(nil), // 37: google.protobuf.Timestamp + (*durationpb.Duration)(nil), // 38: google.protobuf.Duration } var file_proxy_service_proto_depIdxs = []int32{ - 34, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp + 37, // 0: management.GetMappingUpdateRequest.started_at:type_name -> google.protobuf.Timestamp 3, // 1: management.GetMappingUpdateRequest.capabilities:type_name -> management.ProxyCapabilities 11, // 2: management.GetMappingUpdateResponse.mapping:type_name -> management.ProxyMapping - 35, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration + 38, // 3: management.PathTargetOptions.request_timeout:type_name -> google.protobuf.Duration 1, // 4: management.PathTargetOptions.path_rewrite:type_name -> management.PathRewriteMode - 32, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry - 35, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration + 35, // 5: management.PathTargetOptions.custom_headers:type_name -> management.PathTargetOptions.CustomHeadersEntry + 38, // 6: management.PathTargetOptions.session_idle_timeout:type_name -> google.protobuf.Duration 6, // 7: management.PathMapping.options:type_name -> management.PathTargetOptions 8, // 8: management.Authentication.header_auths:type_name -> management.HeaderAuth 0, // 9: management.ProxyMapping.type:type_name -> management.ProxyMappingUpdateType @@ -2688,38 +3071,41 @@ var file_proxy_service_proto_depIdxs = []int32{ 9, // 11: management.ProxyMapping.auth:type_name -> management.Authentication 10, // 12: management.ProxyMapping.access_restrictions:type_name -> management.AccessRestrictions 14, // 13: management.SendAccessLogRequest.log:type_name -> management.AccessLog - 34, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp - 33, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry + 37, // 14: management.AccessLog.timestamp:type_name -> google.protobuf.Timestamp + 36, // 15: management.AccessLog.metadata:type_name -> management.AccessLog.MetadataEntry 17, // 16: management.AuthenticateRequest.password:type_name -> management.PasswordRequest 18, // 17: management.AuthenticateRequest.pin:type_name -> management.PinRequest 16, // 18: management.AuthenticateRequest.header_auth:type_name -> management.HeaderAuthRequest 2, // 19: management.SendStatusUpdateRequest.status:type_name -> management.ProxyStatus - 29, // 20: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit - 30, // 21: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck - 34, // 22: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp - 3, // 23: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities - 11, // 24: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping - 4, // 25: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest - 28, // 26: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest - 12, // 27: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest - 15, // 28: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest - 20, // 29: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest - 22, // 30: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest - 24, // 31: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest - 26, // 32: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest - 5, // 33: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse - 31, // 34: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse - 13, // 35: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse - 19, // 36: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse - 21, // 37: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse - 23, // 38: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse - 25, // 39: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse - 27, // 40: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse - 33, // [33:41] is the sub-list for method output_type - 25, // [25:33] is the sub-list for method input_type - 25, // [25:25] is the sub-list for extension type_name - 25, // [25:25] is the sub-list for extension extendee - 0, // [0:25] is the sub-list for field type_name + 21, // 20: management.SendStatusUpdateRequest.inbound_listener:type_name -> management.ProxyInboundListener + 32, // 21: management.SyncMappingsRequest.init:type_name -> management.SyncMappingsInit + 33, // 22: management.SyncMappingsRequest.ack:type_name -> management.SyncMappingsAck + 37, // 23: management.SyncMappingsInit.started_at:type_name -> google.protobuf.Timestamp + 3, // 24: management.SyncMappingsInit.capabilities:type_name -> management.ProxyCapabilities + 11, // 25: management.SyncMappingsResponse.mapping:type_name -> management.ProxyMapping + 4, // 26: management.ProxyService.GetMappingUpdate:input_type -> management.GetMappingUpdateRequest + 31, // 27: management.ProxyService.SyncMappings:input_type -> management.SyncMappingsRequest + 12, // 28: management.ProxyService.SendAccessLog:input_type -> management.SendAccessLogRequest + 15, // 29: management.ProxyService.Authenticate:input_type -> management.AuthenticateRequest + 20, // 30: management.ProxyService.SendStatusUpdate:input_type -> management.SendStatusUpdateRequest + 23, // 31: management.ProxyService.CreateProxyPeer:input_type -> management.CreateProxyPeerRequest + 25, // 32: management.ProxyService.GetOIDCURL:input_type -> management.GetOIDCURLRequest + 27, // 33: management.ProxyService.ValidateSession:input_type -> management.ValidateSessionRequest + 29, // 34: management.ProxyService.ValidateTunnelPeer:input_type -> management.ValidateTunnelPeerRequest + 5, // 35: management.ProxyService.GetMappingUpdate:output_type -> management.GetMappingUpdateResponse + 34, // 36: management.ProxyService.SyncMappings:output_type -> management.SyncMappingsResponse + 13, // 37: management.ProxyService.SendAccessLog:output_type -> management.SendAccessLogResponse + 19, // 38: management.ProxyService.Authenticate:output_type -> management.AuthenticateResponse + 22, // 39: management.ProxyService.SendStatusUpdate:output_type -> management.SendStatusUpdateResponse + 24, // 40: management.ProxyService.CreateProxyPeer:output_type -> management.CreateProxyPeerResponse + 26, // 41: management.ProxyService.GetOIDCURL:output_type -> management.GetOIDCURLResponse + 28, // 42: management.ProxyService.ValidateSession:output_type -> management.ValidateSessionResponse + 30, // 43: management.ProxyService.ValidateTunnelPeer:output_type -> management.ValidateTunnelPeerResponse + 35, // [35:44] is the sub-list for method output_type + 26, // [26:35] is the sub-list for method input_type + 26, // [26:26] is the sub-list for extension type_name + 26, // [26:26] is the sub-list for extension extendee + 0, // [0:26] is the sub-list for field type_name } func init() { file_proxy_service_proto_init() } @@ -2945,7 +3331,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SendStatusUpdateResponse); i { + switch v := v.(*ProxyInboundListener); i { case 0: return &v.state case 1: @@ -2957,7 +3343,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerRequest); i { + switch v := v.(*SendStatusUpdateResponse); i { case 0: return &v.state case 1: @@ -2969,7 +3355,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateProxyPeerResponse); i { + switch v := v.(*CreateProxyPeerRequest); i { case 0: return &v.state case 1: @@ -2981,7 +3367,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLRequest); i { + switch v := v.(*CreateProxyPeerResponse); i { case 0: return &v.state case 1: @@ -2993,7 +3379,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetOIDCURLResponse); i { + switch v := v.(*GetOIDCURLRequest); i { case 0: return &v.state case 1: @@ -3005,7 +3391,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionRequest); i { + switch v := v.(*GetOIDCURLResponse); i { case 0: return &v.state case 1: @@ -3017,7 +3403,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ValidateSessionResponse); i { + switch v := v.(*ValidateSessionRequest); i { case 0: return &v.state case 1: @@ -3029,7 +3415,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsRequest); i { + switch v := v.(*ValidateSessionResponse); i { case 0: return &v.state case 1: @@ -3041,7 +3427,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsInit); i { + switch v := v.(*ValidateTunnelPeerRequest); i { case 0: return &v.state case 1: @@ -3053,7 +3439,7 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SyncMappingsAck); i { + switch v := v.(*ValidateTunnelPeerResponse); i { case 0: return &v.state case 1: @@ -3065,6 +3451,42 @@ func file_proxy_service_proto_init() { } } file_proxy_service_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsInit); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SyncMappingsAck); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_service_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SyncMappingsResponse); i { case 0: return &v.state @@ -3084,8 +3506,8 @@ func file_proxy_service_proto_init() { (*AuthenticateRequest_HeaderAuth)(nil), } file_proxy_service_proto_msgTypes[17].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[20].OneofWrappers = []interface{}{} - file_proxy_service_proto_msgTypes[25].OneofWrappers = []interface{}{ + file_proxy_service_proto_msgTypes[21].OneofWrappers = []interface{}{} + file_proxy_service_proto_msgTypes[28].OneofWrappers = []interface{}{ (*SyncMappingsRequest_Init)(nil), (*SyncMappingsRequest_Ack)(nil), } @@ -3095,7 +3517,7 @@ func file_proxy_service_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_proxy_service_proto_rawDesc, NumEnums: 3, - NumMessages: 31, + NumMessages: 34, NumExtensions: 0, NumServices: 1, }, diff --git a/shared/management/proto/proxy_service.proto b/shared/management/proto/proxy_service.proto index d1171b27e..71e18c721 100644 --- a/shared/management/proto/proxy_service.proto +++ b/shared/management/proto/proxy_service.proto @@ -34,6 +34,15 @@ service ProxyService { // ValidateSession validates a session token and checks user access permissions. // Called by the proxy after receiving a session token from OIDC callback. rpc ValidateSession(ValidateSessionRequest) returns (ValidateSessionResponse); + + // ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and + // checks the resolved user's access against the service's access_groups. + // Acts as a fast-path equivalent of OIDC for requests originating on the + // netbird mesh: when the source IP maps to a known peer in the calling + // account and that peer is in the service's access_groups, the proxy can + // issue a session cookie without redirecting through the OIDC flow. + // Mirrors ValidateSession's response shape. + rpc ValidateTunnelPeer(ValidateTunnelPeerRequest) returns (ValidateTunnelPeerResponse); } // ProxyCapabilities describes what a proxy can handle. @@ -45,6 +54,13 @@ message ProxyCapabilities { optional bool require_subdomain = 2; // Whether the proxy has CrowdSec configured and can enforce IP reputation checks. optional bool supports_crowdsec = 3; + // Whether the proxy is running embedded in the netbird client and serving + // exclusively over the WireGuard tunnel (i.e. `netbird proxy` rather than + // the standalone netbird-proxy binary). Surfaces upstream so dashboards can + // distinguish per-peer / private clusters from centralised ones. + optional bool private = 4; + // Whether the proxy enforces ProxyMapping.private (fails closed on ValidateTunnelPeer failure). Management MUST NOT stream private mappings to proxies that don't claim this. + optional bool supports_private_service = 5; } // GetMappingUpdateRequest is sent to initialise a mapping stream. @@ -86,6 +102,11 @@ message PathTargetOptions { bool proxy_protocol = 5; // Idle timeout before a UDP session is reaped. google.protobuf.Duration session_idle_timeout = 6; + // When true, the proxy dials this target via the host's network stack + // instead of through the embedded NetBird client. Useful for upstreams + // reachable without WireGuard (public APIs, LAN services, localhost + // sidecars). Defaults to false โ€” embedded client is the standard path. + bool direct_upstream = 7; } message PathMapping { @@ -138,6 +159,8 @@ message ProxyMapping { // For L4/TLS: the port the proxy listens on. int32 listen_port = 11; AccessRestrictions access_restrictions = 12; + // NetBird-only: the proxy MUST call ValidateTunnelPeer and fail closed; operator auth schemes are bypassed. + bool private = 13; } // SendAccessLogRequest consists of one or more AccessLogs from a Proxy. @@ -213,6 +236,25 @@ message SendStatusUpdateRequest { ProxyStatus status = 3; bool certificate_issued = 4; optional string error_message = 5; + // Per-account inbound listener state for the account that owns + // service_id. Populated only when --private-inbound is enabled and the + // embedded client for the account is up. Field numbers >=50 reserved + // for observability extensions. + optional ProxyInboundListener inbound_listener = 50; +} + +// ProxyInboundListener describes a per-account inbound listener that the +// proxy has bound on the embedded netstack of the account's WireGuard +// client. Surfaced so dashboards can render "this account is reachable +// at : on this proxy". +message ProxyInboundListener { + // Tunnel IP the embedded netstack listens on. Same address other peers + // in the account see for the proxy peer. + string tunnel_ip = 1; + // TLS port served on tunnel_ip (auto-detected, default 443). + uint32 https_port = 2; + // Plain-HTTP port served on tunnel_ip (auto-detected, default 80). + uint32 http_port = 3; } // SendStatusUpdateResponse is intentionally empty to allow for future expansion @@ -254,6 +296,52 @@ message ValidateSessionResponse { string user_id = 2; string user_email = 3; string denied_reason = 4; + // peer_group_ids carries the calling user's group memberships so the + // proxy can authorise policy-aware middlewares without an additional + // management round-trip. + repeated string peer_group_ids = 5; + // peer_group_names carries the human-readable display names for the + // ids in peer_group_ids, ordered identically (positional pairing). + // Stamped onto upstream requests as X-NetBird-Groups so downstream + // services can read names rather than opaque ids. + repeated string peer_group_names = 6; +} + +// ValidateTunnelPeerRequest carries the inbound peer's tunnel IP and the +// service domain whose group requirements should gate access. The calling +// account is inferred from the proxy's gRPC metadata (ProxyToken). +message ValidateTunnelPeerRequest { + string tunnel_ip = 1; + string domain = 2; +} + +// ValidateTunnelPeerResponse mirrors ValidateSessionResponse plus a freshly +// minted session_token: when valid is true, the proxy installs the token as +// a session cookie so subsequent requests skip the management round-trip, +// matching the OIDC flow's UX. denied_reason values: +// "peer_not_found" โ€” no peer with that tunnel IP in the calling account +// "no_user" โ€” peer exists but is not bound to a user +// "service_not_found" +// "account_mismatch" +// "not_in_group" โ€” peer resolved but not in service.access_groups +message ValidateTunnelPeerResponse { + bool valid = 1; + string user_id = 2; + string user_email = 3; + string denied_reason = 4; + // session_token is set only when valid is true. Same shape as the JWT + // the OIDC flow produces โ€” proxy installs it via setSessionCookie so the + // tunnel fast-path is indistinguishable from OIDC for subsequent requests. + string session_token = 5; + // peer_group_ids carries the resolved peer's user group memberships so + // the proxy can authorise policy-aware middlewares without an additional + // management round-trip. + repeated string peer_group_ids = 6; + // peer_group_names carries the human-readable display names for the + // ids in peer_group_ids, ordered identically (positional pairing). + // Stamped onto upstream requests as X-NetBird-Groups so downstream + // services can read names rather than opaque ids. + repeated string peer_group_names = 7; } // SyncMappingsRequest is sent by the proxy on the bidirectional SyncMappings @@ -287,3 +375,4 @@ message SyncMappingsResponse { // initial_sync_complete is set on the last message of the initial snapshot. bool initial_sync_complete = 2; } + diff --git a/shared/management/proto/proxy_service_grpc.pb.go b/shared/management/proto/proxy_service_grpc.pb.go index fdc031ed7..40064fe61 100644 --- a/shared/management/proto/proxy_service_grpc.pb.go +++ b/shared/management/proto/proxy_service_grpc.pb.go @@ -35,6 +35,14 @@ type ProxyServiceClient interface { // ValidateSession validates a session token and checks user access permissions. // Called by the proxy after receiving a session token from OIDC callback. ValidateSession(ctx context.Context, in *ValidateSessionRequest, opts ...grpc.CallOption) (*ValidateSessionResponse, error) + // ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and + // checks the resolved user's access against the service's access_groups. + // Acts as a fast-path equivalent of OIDC for requests originating on the + // netbird mesh: when the source IP maps to a known peer in the calling + // account and that peer is in the service's access_groups, the proxy can + // issue a session cookie without redirecting through the OIDC flow. + // Mirrors ValidateSession's response shape. + ValidateTunnelPeer(ctx context.Context, in *ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*ValidateTunnelPeerResponse, error) } type proxyServiceClient struct { @@ -162,6 +170,15 @@ func (c *proxyServiceClient) ValidateSession(ctx context.Context, in *ValidateSe return out, nil } +func (c *proxyServiceClient) ValidateTunnelPeer(ctx context.Context, in *ValidateTunnelPeerRequest, opts ...grpc.CallOption) (*ValidateTunnelPeerResponse, error) { + out := new(ValidateTunnelPeerResponse) + err := c.cc.Invoke(ctx, "/management.ProxyService/ValidateTunnelPeer", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ProxyServiceServer is the server API for ProxyService service. // All implementations must embed UnimplementedProxyServiceServer // for forward compatibility @@ -183,6 +200,14 @@ type ProxyServiceServer interface { // ValidateSession validates a session token and checks user access permissions. // Called by the proxy after receiving a session token from OIDC callback. ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error) + // ValidateTunnelPeer resolves an inbound peer by its WireGuard tunnel IP and + // checks the resolved user's access against the service's access_groups. + // Acts as a fast-path equivalent of OIDC for requests originating on the + // netbird mesh: when the source IP maps to a known peer in the calling + // account and that peer is in the service's access_groups, the proxy can + // issue a session cookie without redirecting through the OIDC flow. + // Mirrors ValidateSession's response shape. + ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error) mustEmbedUnimplementedProxyServiceServer() } @@ -214,6 +239,9 @@ func (UnimplementedProxyServiceServer) GetOIDCURL(context.Context, *GetOIDCURLRe func (UnimplementedProxyServiceServer) ValidateSession(context.Context, *ValidateSessionRequest) (*ValidateSessionResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ValidateSession not implemented") } +func (UnimplementedProxyServiceServer) ValidateTunnelPeer(context.Context, *ValidateTunnelPeerRequest) (*ValidateTunnelPeerResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ValidateTunnelPeer not implemented") +} func (UnimplementedProxyServiceServer) mustEmbedUnimplementedProxyServiceServer() {} // UnsafeProxyServiceServer may be embedded to opt out of forward compatibility for this service. @@ -382,6 +410,24 @@ func _ProxyService_ValidateSession_Handler(srv interface{}, ctx context.Context, return interceptor(ctx, in, info, handler) } +func _ProxyService_ValidateTunnelPeer_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ValidateTunnelPeerRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyServiceServer).ValidateTunnelPeer(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/management.ProxyService/ValidateTunnelPeer", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyServiceServer).ValidateTunnelPeer(ctx, req.(*ValidateTunnelPeerRequest)) + } + return interceptor(ctx, in, info, handler) +} + // ProxyService_ServiceDesc is the grpc.ServiceDesc for ProxyService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -413,6 +459,10 @@ var ProxyService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ValidateSession", Handler: _ProxyService_ValidateSession_Handler, }, + { + MethodName: "ValidateTunnelPeer", + Handler: _ProxyService_ValidateTunnelPeer_Handler, + }, }, Streams: []grpc.StreamDesc{ { From b3b0feb3b8500f8ee181445ceb1f5b79a4426269 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 26 May 2026 01:38:21 +0900 Subject: [PATCH 25/31] [client] Filter scoped/cloned default routes from BSD network monitor RTM_ADD (#6208) --- .../networkmonitor/check_change_common.go | 20 ++++++++++++------ .../systemops/routeflags_addfilter_bsd.go | 9 ++++++++ .../systemops/routeflags_addfilter_darwin.go | 21 +++++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 client/internal/routemanager/systemops/routeflags_addfilter_bsd.go create mode 100644 client/internal/routemanager/systemops/routeflags_addfilter_darwin.go diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go index a4a4f76ac..f693081a6 100644 --- a/client/internal/networkmonitor/check_change_common.go +++ b/client/internal/networkmonitor/check_change_common.go @@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next switch msg.Type { // handle route changes case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) + route, flags, err := parseRouteMessage(buf[:n]) if err != nil { log.Debugf("Network monitor: error parsing routing message: %v", err) continue @@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next } switch msg.Type { case unix.RTM_ADD: + if systemops.IgnoreAddedDefaultRoute(flags) { + log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags) + continue + } log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) return nil case unix.RTM_DELETE: @@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next } } -func parseRouteMessage(buf []byte) (*systemops.Route, error) { +func parseRouteMessage(buf []byte) (*systemops.Route, int, error) { msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) + return nil, 0, fmt.Errorf("parse RIB: %v", err) } if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs) } msg, ok := msgs[0].(*route.RouteMessage) if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) } - return systemops.MsgToRoute(msg) + r, err := systemops.MsgToRoute(msg) + if err != nil { + return nil, 0, err + } + return r, msg.Flags, nil } // waitReadable blocks until fd has data to read, or ctx is cancelled. diff --git a/client/internal/routemanager/systemops/routeflags_addfilter_bsd.go b/client/internal/routemanager/systemops/routeflags_addfilter_bsd.go new file mode 100644 index 000000000..45a1bfceb --- /dev/null +++ b/client/internal/routemanager/systemops/routeflags_addfilter_bsd.go @@ -0,0 +1,9 @@ +//go:build dragonfly || freebsd || netbsd || openbsd + +package systemops + +// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the +// given flags should be ignored by the network monitor. +func IgnoreAddedDefaultRoute(flags int) bool { + return filterRoutesByFlags(flags) +} diff --git a/client/internal/routemanager/systemops/routeflags_addfilter_darwin.go b/client/internal/routemanager/systemops/routeflags_addfilter_darwin.go new file mode 100644 index 000000000..e8f655387 --- /dev/null +++ b/client/internal/routemanager/systemops/routeflags_addfilter_darwin.go @@ -0,0 +1,21 @@ +//go:build darwin + +package systemops + +import "golang.org/x/sys/unix" + +// IgnoreAddedDefaultRoute reports whether an RTM_ADD default route with the +// given flags should be ignored by the network monitor. Scoped routes +// (RTF_IFSCOPE) are tied to a specific interface index and cannot replace the +// unscoped default the kernel uses for general egress, so flapping ones (e.g. +// Wi-Fi calling IMS tunnels on ipsec0, Docker bridges, scoped utun defaults) +// must not trigger an engine restart. +func IgnoreAddedDefaultRoute(flags int) bool { + if filterRoutesByFlags(flags) { + return true + } + if flags&unix.RTF_IFSCOPE != 0 { + return true + } + return false +} From 4983b5cf17bfaaa733f2084b97a521143b94b179 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 26 May 2026 01:38:48 +0900 Subject: [PATCH 26/31] [client] Match DNS wildcard handlers on label boundaries (#6255) --- client/internal/dns/handler_chain.go | 3 +- client/internal/dns/handler_chain_test.go | 61 +++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 57e7722d4..dc20146eb 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { case entry.Pattern == ".": return true case entry.IsWildcard: - parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") - return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + return strings.HasSuffix(qname, "."+entry.Pattern) default: // For non-wildcard patterns: // If handler wants subdomain matching, allow suffix match diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 034a760dc..b3db97ba3 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { matchSubdomains: true, shouldMatch: true, }, + { + name: "wildcard label-boundary mismatch (suffix overlap)", + handlerDomain: "*.b.test.", + queryDomain: "x.ab.test.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard label-boundary match", + handlerDomain: "*.b.test.", + queryDomain: "x.b.test.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "wildcard multi-label match", + handlerDomain: "*.b.test.", + queryDomain: "x.y.b.test.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "wildcard no match on multi-label apex", + handlerDomain: "*.b.test.", + queryDomain: "b.test.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard no match on unrelated suffix containment", + handlerDomain: "*.example.com.", + queryDomain: "notexample.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard accepts pattern registered without trailing dot", + handlerDomain: "*.b.test", + queryDomain: "x.b.test.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, } for _, tt := range tests { @@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { expectedCalls: 1, expectedHandler: 2, // highest priority matching handler should be called }, + { + name: "overlapping wildcard suffixes route to correct handler", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute}, + {pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "app.ab.test.", + expectedCalls: 1, + expectedHandler: 1, + }, { name: "root zone with specific domain", handlers: []struct { From d542c60e2182ac83951056e2ac6a96b9e15c1593 Mon Sep 17 00:00:00 2001 From: Philip Laine Date: Mon, 25 May 2026 21:00:24 +0200 Subject: [PATCH 27/31] Refactor Linux system info to use syscalls (#6230) --- client/system/info_linux.go | 39 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 6c7a23b95..de37a9f5b 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -3,15 +3,14 @@ package system import ( - "bytes" "context" "os" - "os/exec" "regexp" "runtime" - "strings" "time" + "golang.org/x/sys/unix" + log "github.com/sirupsen/logrus" "github.com/zcalusic/sysinfo" @@ -29,19 +28,11 @@ func UpdateStaticInfoAsync() { // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - info := _getInfo() - for strings.Contains(info, "broken pipe") { - info = _getInfo() - time.Sleep(500 * time.Millisecond) - } - - osStr := strings.ReplaceAll(info, "\n", "") - osStr = strings.ReplaceAll(osStr, "\r\n", "") - osInfo := strings.Split(osStr, " ") + kernelName, kernelVersion, kernelPlatform := kernelInfo() osName, osVersion := readOsReleaseFile() if osName == "" { - osName = osInfo[3] + osName = kernelName } systemHostname, _ := os.Hostname() @@ -58,8 +49,8 @@ func GetInfo(ctx context.Context) *Info { } gio := &Info{ - Kernel: osInfo[0], - Platform: osInfo[2], + Kernel: kernelName, + Platform: kernelPlatform, OS: osName, OSVersion: osVersion, Hostname: extractDeviceName(ctx, systemHostname), @@ -67,7 +58,7 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), NetbirdVersion: version.NetbirdVersion(), UIVersion: extractUserAgent(ctx), - KernelVersion: osInfo[1], + KernelVersion: kernelVersion, NetworkAddresses: addrs, SystemSerialNumber: si.SystemSerialNumber, SystemProductName: si.SystemProductName, @@ -78,18 +69,12 @@ func GetInfo(ctx context.Context) *Info { return gio } -func _getInfo() string { - cmd := exec.Command("uname", "-srio") - cmd.Stdin = strings.NewReader("some") - var out bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &out - cmd.Stderr = &stderr - err := cmd.Run() - if err != nil { - log.Warnf("getInfo: %s", err) +func kernelInfo() (string, string, string) { + var uts unix.Utsname + if err := unix.Uname(&uts); err != nil { + return "", "", "" } - return out.String() + return unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]), unix.ByteSliceToString(uts.Machine[:]) } func sysInfo() (string, string, string) { From e89b1e0596bd0185193b065faa34d13b216d18bb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 26 May 2026 18:51:53 +0900 Subject: [PATCH 28/31] [proxy, client] Bound embed client WireGuard per-Device memory (#5962) --- client/embed/embed.go | 47 ++++++++ client/internal/engine.go | 23 ++++ go.mod | 2 +- go.sum | 4 +- proxy/cmd/proxy/cmd/debug.go | 30 +++++ proxy/cmd/proxy/cmd/root.go | 51 +++++++++ proxy/internal/debug/client.go | 57 ++++++++++ proxy/internal/debug/handler.go | 164 +++++++++++++++++++++++----- proxy/internal/roundtrip/netbird.go | 4 +- proxy/lifecycle.go | 6 + proxy/server.go | 7 ++ 11 files changed, 365 insertions(+), 30 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index 7e7f6c337..04bc60fb8 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/sirupsen/logrus" + wgdevice "golang.zx2c4.com/wireguard/device" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" @@ -100,6 +101,26 @@ type Options struct { MTU *uint16 // DNSLabels defines additional DNS labels configured in the peer. DNSLabels []string + // Performance configures the tunnel's buffer pool cap and batch size. + Performance Performance +} + +// Performance configures the embedded client's tunnel memory/throughput knobs. +// +// These settings are process-global: any non-nil field also becomes the +// default for Clients constructed by later embed.New calls in the same +// process. Nil fields are ignored. +type Performance struct { + // PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero + // leaves the pool unbounded. Lower values trade throughput for a + // tighter memory ceiling. May also be changed on a running Client via + // Client.SetPerformance, provided this field was nonzero at construction. + PreallocatedBuffersPerPool *uint32 + // MaxBatchSize overrides the number of packets the tunnel reads or + // writes per syscall, which also bounds eager buffer allocation per + // worker. Zero uses the platform default. Applied at construction + // only; ignored by Client.SetPerformance. + MaxBatchSize *uint32 } // validateCredentials checks that exactly one credential type is provided @@ -199,6 +220,13 @@ func New(opts Options) (*Client, error) { config.PrivateKey = opts.PrivateKey } + if opts.Performance.PreallocatedBuffersPerPool != nil { + wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool) + } + if opts.Performance.MaxBatchSize != nil { + wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize) + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, @@ -495,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool +// takes effect, and only when it was nonzero at construction; +// MaxBatchSize is construction-only and returns an error if set here. +// +// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not +// running yet. +func (c *Client) SetPerformance(t Performance) error { + if t.MaxBatchSize != nil { + return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime") + } + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetPerformance(internal.Performance{ + PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool, + }) +} + // StartCapture begins capturing packets on this client's tunnel device. // Only one capture can be active at a time; starting a new one stops the previous. // Call StopCapture (or CaptureSession.Stop) to end it. diff --git a/client/internal/engine.go b/client/internal/engine.go index 3bd0d4621..b82eb95b7 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { return e.clientMetrics } +// Performance bundles runtime-adjustable tunnel pool knobs. +// See Engine.SetPerformance. Nil fields are ignored. +type Performance struct { + PreallocatedBuffersPerPool *uint32 +} + +// SetPerformance applies the given tuning to this engine's live Device. +func (e *Engine) SetPerformance(t Performance) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + if e.wgInterface == nil { + return fmt.Errorf("wg interface not initialized") + } + dev := e.wgInterface.GetWGDevice() + if dev == nil { + return fmt.Errorf("wg device not initialized") + } + if t.PreallocatedBuffersPerPool != nil { + dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool) + } + return nil +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/go.mod b/go.mod index 7c1a95e79..ea0d8d73d 100644 --- a/go.mod +++ b/go.mod @@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 53789f49d..f95efefa6 100644 --- a/go.sum +++ b/go.sum @@ -499,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= +github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk= +github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 49afc7638..360c7a516 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -109,6 +109,22 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugPerfCmd = &cobra.Command{ + Use: "perf ", + Short: "Live-retune the tunnel buffer pool cap on all running clients", + Args: cobra.ExactArgs(1), + RunE: runDebugPerfSet, + SilenceUsage: true, +} + +var debugRuntimeCmd = &cobra.Command{ + Use: "runtime", + Short: "Show runtime stats (heap, goroutines, RSS)", + Args: cobra.NoArgs, + RunE: runDebugRuntime, + SilenceUsage: true, +} + var debugCaptureCmd = &cobra.Command{ Use: "capture [filter expression]", Short: "Capture packets on a client's WireGuard interface", @@ -159,6 +175,8 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugPerfCmd) + debugCmd.AddCommand(debugRuntimeCmd) debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) @@ -220,6 +238,18 @@ func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } +func runDebugPerfSet(cmd *cobra.Command, args []string) error { + n, err := strconv.ParseUint(args[0], 10, 32) + if err != nil { + return fmt.Errorf("invalid value %q: %w", args[0], err) + } + return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n)) +} + +func runDebugRuntime(cmd *cobra.Command, _ []string) error { + return getDebugClient(cmd).Runtime(cmd.Context()) +} + func runDebugCapture(cmd *cobra.Command, args []string) error { duration, _ := cmd.Flags().GetDuration("duration") forcePcap, _ := cmd.Flags().GetBool("pcap") diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 5970886da..405fa2789 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -15,11 +15,22 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy" nbacme "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/util" ) +const ( + // envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset) + // keeps the upstream uncapped default. + envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS" + // envMaxBatchSize overrides the per-tunnel batch size, which controls + // how many buffers each receive/TUN worker eagerly allocates. Zero + // (unset) keeps the platform default. + envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE" +) + const DefaultManagementURL = "https://api.netbird.io:443" // envProxyToken is the environment variable name for the proxy access token. @@ -148,6 +159,45 @@ func runServer(cmd *cobra.Command, args []string) error { logger.Infof("configured log level: %s", level) + var wgPool, wgBatch uint64 + var perf embed.Performance + if raw := os.Getenv(envPreallocatedBuffers); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err) + } + wgPool = n + v := uint32(n) + perf.PreallocatedBuffersPerPool = &v + logger.Infof("tunnel preallocated buffers per pool: %d", n) + } + if raw := os.Getenv(envMaxBatchSize); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err) + } + wgBatch = n + v := uint32(n) + perf.MaxBatchSize = &v + logger.Infof("tunnel max batch size override: %d", n) + } + if wgPool > 0 { + // Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus + // RoutineReadFromTUN eagerly reserves `batch` message buffers for + // the lifetime of the Device. A pool cap below that floor blocks + // the receive pipeline at startup. + batch := wgBatch + if batch == 0 { + batch = 128 + } + const recvGoroutines = 4 + floor := batch * recvGoroutines + if wgPool < floor { + logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock", + envPreallocatedBuffers, wgPool, floor, batch) + } + } + switch forwardedProto { case "auto", "http", "https": default: @@ -188,6 +238,7 @@ func runServer(cmd *cobra.Command, args []string) error { CertLockMethod: nbacme.CertLockMethod(certLockMethod), WildcardCertDir: wildcardCertDir, WireguardPort: wgPort, + Performance: perf, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, SupportsCustomPorts: supportsCustomPorts, diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 736781652..77772637c 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -333,6 +333,63 @@ func (c *Client) printLogLevelResult(data map[string]any) { } } +// PerfSet live-retunes the tunnel buffer pool cap on all running embedded +// clients. Batch size is not live-tunable; configure it at proxy startup. +func (c *Client) PerfSet(ctx context.Context, value uint32) error { + path := fmt.Sprintf("/debug/perf?value=%d", value) + return c.fetchAndPrint(ctx, path, c.printPerfSet) +} + +func (c *Client) printPerfSet(data map[string]any) { + if errMsg, ok := data["error"].(string); ok && errMsg != "" { + c.printError(data) + return + } + val, _ := data["value"].(float64) + applied, _ := data["applied"].(float64) + _, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val)) + _, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied)) + if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 { + _, _ = fmt.Fprintln(c.out, "Failed:") + for k, v := range failed { + _, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v) + } + } +} + +// Runtime fetches runtime stats (heap, goroutines, RSS). +func (c *Client) Runtime(ctx context.Context) error { + return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime) +} + +func (c *Client) printRuntime(data map[string]any) { + i := func(k string) uint64 { + v, _ := data[k].(float64) + return uint64(v) + } + mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) } + + _, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"]) + _, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs"))) + _, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines")) + _, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects")) + _, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns"))) + _, _ = fmt.Fprintln(c.out, "Heap:") + _, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc"))) + _, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse"))) + _, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle"))) + _, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released"))) + _, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys"))) + _, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys"))) + if _, ok := data["vm_rss"]; ok { + _, _ = fmt.Fprintln(c.out, "Process:") + _, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss"))) + _, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size"))) + _, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data"))) + } + _, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started")) +} + // StartClient starts a specific client. func (c *Client) StartClient(ctx context.Context, accountID string) error { path := "/debug/clients/" + url.PathEscape(accountID) + "/start" diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 1dbfe1522..826c6817f 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -11,6 +11,8 @@ import ( "maps" "net" "net/http" + "os" + "runtime" "slices" "strconv" "strings" @@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A type clientProvider interface { GetClient(accountID types.AccountID) (*nbembed.Client, bool) ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo + ListClientsForStartup() map[types.AccountID]*nbembed.Client } // InboundListenerInfo describes a per-account inbound listener as @@ -165,6 +168,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleListClients(w, r, wantJSON) case "/debug/health": h.handleHealth(w, r, wantJSON) + case "/debug/perf": + h.handlePerf(w, r) + case "/debug/runtime": + h.handleRuntime(w, r) default: if h.handleClientRoutes(w, r, path, wantJSON) { return @@ -258,10 +265,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } if wantJSON { - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - clientsJSON = append(clientsJSON, map[string]interface{}{ + clientsJSON = append(clientsJSON, map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -270,7 +277,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } - resp := map[string]interface{}{ + resp := map[string]any{ "version": version.NetbirdVersion(), "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), @@ -352,10 +359,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want if h.inbound != nil { inboundAll = h.inbound.InboundListeners() } - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - row := map[string]interface{}{ + row := map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -368,7 +375,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } clientsJSON = append(clientsJSON, row) } - resp := map[string]interface{}{ + resp := map[string]any{ "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), "clients": clientsJSON, @@ -458,7 +465,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc }) if wantJSON { - resp := map[string]interface{}{ + resp := map[string]any{ "account_id": accountID, "status": overview.FullDetailSummary(), } @@ -557,20 +564,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } host := r.URL.Query().Get("host") portStr := r.URL.Query().Get("port") if host == "" || portStr == "" { - h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"}) + h.writeJSON(w, map[string]any{"error": "host and port parameters required"}) return } port, err := strconv.Atoi(portStr) if err != nil || port < 1 || port > 65535 { - h.writeJSON(w, map[string]interface{}{"error": "invalid port"}) + h.writeJSON(w, map[string]any{"error": "invalid port"}) return } @@ -594,7 +601,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI conn, err := client.Dial(ctx, network, address) if err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "host": host, "port": port, @@ -609,39 +616,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } latency := time.Since(start) - resp := map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "host": host, "port": port, "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - } - h.writeJSON(w, resp) + }) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } level := r.URL.Query().Get("level") if level == "" { - h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"}) + h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"}) return } if err := client.SetLogLevel(level); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "level": level, }) @@ -652,7 +658,7 @@ const clientActionTimeout = 30 * time.Second func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -660,14 +666,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco defer cancel() if err := client.Start(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client started", }) @@ -676,7 +682,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -684,19 +690,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou defer cancel() if err := client.Stop(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client stopped", }) } +func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) { + raw := r.URL.Query().Get("value") + if raw == "" { + http.Error(w, "value parameter is required", http.StatusBadRequest) + return + } + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest) + return + } + + capN := uint32(n) + applied := 0 + failed := map[string]string{} + for accountID, client := range h.provider.ListClientsForStartup() { + if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil { + failed[string(accountID)] = err.Error() + continue + } + applied++ + } + + resp := map[string]any{ + "success": true, + "value": capN, + "applied": applied, + } + if len(failed) > 0 { + resp["failed"] = failed + } + h.writeJSON(w, resp) +} + +// handleRuntime returns cheap runtime and process stats. Safe to hit on a +// running proxy; does not read pprof profiles. +func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + + clients := h.provider.ListClientsForDebug() + started := 0 + for _, c := range clients { + if c.HasClient { + started++ + } + } + + resp := map[string]any{ + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "goroutines": runtime.NumGoroutine(), + "num_cpu": runtime.NumCPU(), + "gomaxprocs": runtime.GOMAXPROCS(0), + "go_version": runtime.Version(), + "heap_alloc": m.HeapAlloc, + "heap_inuse": m.HeapInuse, + "heap_idle": m.HeapIdle, + "heap_released": m.HeapReleased, + "heap_sys": m.HeapSys, + "sys": m.Sys, + "live_objects": m.Mallocs - m.Frees, + "num_gc": m.NumGC, + "pause_total_ns": m.PauseTotalNs, + "clients": len(clients), + "started": started, + } + + if proc := readProcStatus(); proc != nil { + resp["vm_rss"] = proc["VmRSS"] + resp["vm_size"] = proc["VmSize"] + resp["vm_data"] = proc["VmData"] + } + + h.writeJSON(w, resp) +} + +// readProcStatus parses /proc/self/status on Linux and returns size fields +// in bytes. Returns nil on non-Linux or read failure. +func readProcStatus() map[string]uint64 { + raw, err := os.ReadFile("/proc/self/status") + if err != nil { + return nil + } + out := map[string]uint64{} + for _, line := range strings.Split(string(raw), "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + if k != "VmRSS" && k != "VmSize" && k != "VmData" { + continue + } + fields := strings.Fields(v) + if len(fields) < 1 { + continue + } + n, err := strconv.ParseUint(fields[0], 10, 64) + if err != nil { + continue + } + // Values are reported in kB. + out[k] = n * 1024 + } + return out +} + const maxCaptureDuration = 30 * time.Minute // handleCapture streams a pcap or text packet capture for the given client. @@ -825,7 +937,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON h.writeJSON(w, resp) } -func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) { +func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) { w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl := h.getTemplates() if tmpl == nil { @@ -838,7 +950,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf } } -func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) { +func (h *Handler) writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(w) enc.SetIndent("", " ") diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 133e86f05..11bca22e3 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -131,6 +131,7 @@ type ClientConfig struct { MgmtAddr string WGPort uint16 PreSharedKey string + Performance embed.Performance // BlockInbound mirrors embed.Options.BlockInbound. Set to true on the // standalone proxy where the embedded client never accepts inbound; // set to false on the private/embedded proxy so the engine creates @@ -306,7 +307,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), - BlockInbound: n.clientCfg.BlockInbound, + BlockInbound: n.clientCfg.BlockInbound, // The embedded proxy peer must never be a stepping stone into // the proxy host's LAN: it only exists to reach NetBird mesh // targets or, when direct_upstream is set, the host network @@ -315,6 +316,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account BlockLANAccess: true, WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, + Performance: n.clientCfg.Performance, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) diff --git a/proxy/lifecycle.go b/proxy/lifecycle.go index 6cb420722..9787f237e 100644 --- a/proxy/lifecycle.go +++ b/proxy/lifecycle.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy/internal/acme" ) @@ -89,6 +90,10 @@ type Config struct { // PreSharedKey is the WireGuard pre-shared key used between the // proxy's embedded clients and peers. PreSharedKey string + // Performance configures the tunnel pool/batch sizes for every + // embedded client this proxy creates. Zero values fall back to + // upstream defaults. + Performance embed.Performance // SupportsCustomPorts indicates whether the proxy can bind arbitrary // ports for TCP/UDP/TLS services. @@ -148,6 +153,7 @@ func New(cfg Config) *Server { WireguardPort: cfg.WireguardPort, ProxyProtocol: cfg.ProxyProtocol, PreSharedKey: cfg.PreSharedKey, + Performance: cfg.Performance, SupportsCustomPorts: cfg.SupportsCustomPorts, RequireSubdomain: cfg.RequireSubdomain, Private: cfg.Private, diff --git a/proxy/server.go b/proxy/server.go index 63a0c577a..037da925c 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -41,6 +41,7 @@ import ( goproto "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" @@ -185,6 +186,9 @@ type Server struct { // single-account deployments; multiple accounts will fail to bind // the same port. WireguardPort uint16 + // Performance configures the tunnel pool/batch sizes for every + // embedded client this proxy spawns. + Performance embed.Performance // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. @@ -333,6 +337,8 @@ func (s *Server) Start(ctx context.Context) error { s.runCancel = runCancel s.initNetBirdClient() + // Create health checker before the mapping worker so it can track + // management connectivity from the first stream connection. s.healthChecker = health.NewChecker(s.Logger, s.netbird) s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger)) @@ -529,6 +535,7 @@ func (s *Server) initNetBirdClient() { MgmtAddr: s.ManagementAddress, WGPort: s.WireguardPort, PreSharedKey: s.PreSharedKey, + Performance: s.Performance, // On --private the embedded client serves per-account inbound // listeners and must apply management's ACL: keep BlockInbound off // so the engine creates the ACL manager. On the standalone proxy From 6771e35d57ae2411123aa0e4a762f437a6d4c0bb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 26 May 2026 21:32:39 +0900 Subject: [PATCH 29/31] [client] Release js.FuncOf callbacks in wasm ssh and rdp to prevent leaks (#5982) --- client/wasm/internal/rdp/cert_validation.go | 31 ++++-- client/wasm/internal/rdp/rdcleanpath.go | 105 +++++++++++++++----- client/wasm/internal/ssh/handlers.go | 27 +++-- 3 files changed, 121 insertions(+), 42 deletions(-) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 1678c3996..47e30b6e6 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "sync" "syscall/js" "time" @@ -13,7 +14,7 @@ import ( ) const ( - certValidationTimeout = 60 * time.Second + certValidationTimeout = 5 * time.Minute ) func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { @@ -46,17 +47,31 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) - resultChan := make(chan bool) - errorChan := make(chan error) + resultChan := make(chan bool, 1) + errorChan := make(chan error, 1) - promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { - result := args[0].Bool() - resultChan <- result + // Release from inside the callbacks so a post-timeout promise resolution + // does not invoke an already-released func. + var thenFn, catchFn js.Func + var releaseOnce sync.Once + release := func() { + releaseOnce.Do(func() { + thenFn.Release() + catchFn.Release() + }) + } + thenFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + defer release() + resultChan <- args[0].Bool() return nil - })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + }) + catchFn = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + defer release() errorChan <- fmt.Errorf("certificate validation failed") return nil - })) + }) + + promise.Call("then", thenFn).Call("catch", catchFn) select { case result := <-resultChan: diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 6c36fdec6..ee420dca4 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -11,6 +11,7 @@ import ( "io" "net" "sync" + "sync/atomic" "syscall/js" "time" @@ -57,6 +58,8 @@ type RDCleanPathProxy struct { } activeConnections map[string]*proxyConnection destinations map[string]string + pendingHandlers map[string]js.Func + nextID atomic.Uint64 mu sync.Mutex } @@ -66,8 +69,15 @@ type proxyConnection struct { rdpConn net.Conn tlsConn *tls.Conn wsHandlers js.Value - ctx context.Context - cancel context.CancelFunc + // Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a + // global handle map and MUST be released, otherwise every connection + // leaks the Go memory the closure captures. + wsHandlerFn js.Func + onMessageFn js.Func + onCloseFn js.Func + cleanupOnce sync.Once + ctx context.Context + cancel context.CancelFunc } // NewRDCleanPathProxy creates a new RDCleanPath proxy @@ -80,7 +90,11 @@ func NewRDCleanPathProxy(client interface { } } -// CreateProxy creates a new proxy endpoint for the given destination +// CreateProxy creates a new proxy endpoint for the given destination. +// The registered handler fn and its destinations/pendingHandlers entries are +// only released once a connection is established and cleanupConnection runs. +// If a caller invokes CreateProxy but never connects to the returned URL, +// those entries stay pinned for the lifetime of the page. func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { destination := net.JoinHostPort(hostname, port) @@ -88,7 +102,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { resolve := args[0] go func() { - proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + proxyID := fmt.Sprintf("proxy_%d", p.nextID.Add(1)) p.mu.Lock() if p.destinations == nil { @@ -100,7 +114,7 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) // Register the WebSocket handler for this specific proxy - js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any { if len(args) < 1 { return js.ValueOf("error: requires WebSocket argument") } @@ -108,7 +122,14 @@ func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { ws := args[0] p.HandleWebSocketConnection(ws, proxyID) return nil - })) + }) + p.mu.Lock() + if p.pendingHandlers == nil { + p.pendingHandlers = make(map[string]js.Func) + } + p.pendingHandlers[proxyID] = handlerFn + p.mu.Unlock() + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), handlerFn) log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) resolve.Invoke(proxyURL) @@ -142,6 +163,10 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string p.mu.Lock() p.activeConnections[proxyID] = conn + if fn, ok := p.pendingHandlers[proxyID]; ok { + conn.wsHandlerFn = fn + delete(p.pendingHandlers, proxyID) + } p.mu.Unlock() p.setupWebSocketHandlers(ws, conn) @@ -150,7 +175,7 @@ func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string } func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { - ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.onMessageFn = js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 1 { return nil } @@ -158,13 +183,15 @@ func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnec data := args[0] go p.handleWebSocketMessage(conn, data) return nil - })) + }) + ws.Set("onGoMessage", conn.onMessageFn) - ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + conn.onCloseFn = js.FuncOf(func(_ js.Value, args []js.Value) any { log.Debug("WebSocket closed by JavaScript") conn.cancel() return nil - })) + }) + ws.Set("onGoClose", conn.onCloseFn) } func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { @@ -261,25 +288,49 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] } func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { - log.Debugf("Cleaning up connection %s", conn.id) - conn.cancel() - if conn.tlsConn != nil { - log.Debug("Closing TLS connection") - if err := conn.tlsConn.Close(); err != nil { - log.Debugf("Error closing TLS connection: %v", err) + conn.cleanupOnce.Do(func() { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil } - conn.tlsConn = nil - } - if conn.rdpConn != nil { - log.Debug("Closing TCP connection") - if err := conn.rdpConn.Close(); err != nil { - log.Debugf("Error closing TCP connection: %v", err) + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil } - conn.rdpConn = nil - } - p.mu.Lock() - delete(p.activeConnections, conn.id) - p.mu.Unlock() + js.Global().Delete(fmt.Sprintf("handleRDCleanPathWebSocket_%s", conn.id)) + + // Detach before releasing so late JS calls surface as TypeError instead + // of silent "call to released function". + if conn.wsHandlers.Truthy() { + conn.wsHandlers.Set("onGoMessage", js.Undefined()) + conn.wsHandlers.Set("onGoClose", js.Undefined()) + } + + // wsHandlerFn may be zero-value if the pending handler lookup missed. + if conn.wsHandlerFn.Truthy() { + conn.wsHandlerFn.Release() + } + if conn.onMessageFn.Truthy() { + conn.onMessageFn.Release() + } + if conn.onCloseFn.Truthy() { + conn.onCloseFn.Release() + } + + p.mu.Lock() + delete(p.activeConnections, conn.id) + delete(p.destinations, conn.id) + delete(p.pendingHandlers, conn.id) + p.mu.Unlock() + }) } func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go index ea64eb0aa..6d33916a5 100644 --- a/client/wasm/internal/ssh/handlers.go +++ b/client/wasm/internal/ssh/handlers.go @@ -13,7 +13,7 @@ import ( func CreateJSInterface(client *Client) js.Value { jsInterface := js.Global().Get("Object").Call("create", js.Null()) - jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + writeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 1 { return js.ValueOf(false) } @@ -32,9 +32,10 @@ func CreateJSInterface(client *Client) js.Value { _, err := client.Write(bytes) return js.ValueOf(err == nil) - })) + }) + jsInterface.Set("write", writeFunc) - jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + resizeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { if len(args) < 2 { return js.ValueOf(false) } @@ -42,14 +43,26 @@ func CreateJSInterface(client *Client) js.Value { rows := args[1].Int() err := client.Resize(cols, rows) return js.ValueOf(err == nil) - })) + }) + jsInterface.Set("resize", resizeFunc) - jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + closeFunc := js.FuncOf(func(this js.Value, args []js.Value) any { client.Close() return js.Undefined() - })) + }) + jsInterface.Set("close", closeFunc) - go readLoop(client, jsInterface) + go func() { + readLoop(client, jsInterface) + // Detach before releasing so late JS calls surface as TypeError instead + // of silent "call to released function". + jsInterface.Set("write", js.Undefined()) + jsInterface.Set("resize", js.Undefined()) + jsInterface.Set("close", js.Undefined()) + writeFunc.Release() + resizeFunc.Release() + closeFunc.Release() + }() return jsInterface } From 1fbb5e6d5d11233fd9dc6324eff9fe66c1d6becb Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 26 May 2026 16:37:58 +0200 Subject: [PATCH 30/31] [management] fix owner role update (#6264) --- management/server/user.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/management/server/user.go b/management/server/user.go index 892d982e7..60571a702 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -762,7 +762,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } // Ensure the initiator still has admin privileges - if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() { + if !freshInitiator.HasAdminPower() { return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing") } initiatorUser = freshInitiator @@ -906,19 +906,23 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse return nil } + if !initiatorUser.HasAdminPower() { + return status.Errorf(status.PermissionDenied, "only admins and owners can update users") + } + if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role != types.UserRoleOwner && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role != types.UserRoleOwner && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { From 14af1795567b91c0d4ceb7e5166eacb5015daee7 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 May 2026 17:44:28 +0300 Subject: [PATCH 31/31] [management] Refactor management server bootstrap (#6256) --- client/cmd/testutil_test.go | 4 +- client/internal/engine_test.go | 6 +- client/server/server_test.go | 4 +- combined/cmd/root.go | 8 ++- management/internals/server/boot.go | 38 ++++++++++-- management/internals/server/controllers.go | 3 +- management/internals/server/modules.go | 13 ++-- management/internals/server/server.go | 6 +- management/server/http/handler.go | 31 ++-------- .../server/http/middleware/auth_middleware.go | 11 ++-- .../http/middleware/auth_middleware_test.go | 9 +++ .../testing/testing_tools/channel/channel.go | 7 ++- .../validator/validator.go | 62 +++++++++++++++++++ shared/management/client/client_test.go | 4 +- 14 files changed, 146 insertions(+), 60 deletions(-) create mode 100644 management/server/integrations/integrated_validator/validator/validator.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index c24965e8d..205327ef5 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -11,7 +11,7 @@ import ( "go.opentelemetry.io/otel" "google.golang.org/grpc" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp t.Fatal(err) } - iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 834a49a09..289f1906f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,7 +27,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/management/server/job" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -66,8 +66,8 @@ import ( "github.com/netbirdio/netbird/route" mgmt "github.com/netbirdio/netbird/shared/management/client" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" - relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/shared/netiputil" + relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/shared/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/server/server_test.go b/client/server/server_test.go index 641cd85fe..66e0fcc4c 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" @@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/combined/cmd/root.go b/combined/cmd/root.go index db986b4d4..78290388b 100644 --- a/combined/cmd/root.go +++ b/combined/cmd/root.go @@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) { log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress) } - s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) + s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg)) if servers.relaySrv != nil { log.Infof("Relay WebSocket handler added (path: /relay)") } @@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (* } // createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic -func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { +func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler { wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) var relayAcceptFn func(conn listener.Conn) @@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re http.Error(w, "Relay service not enabled", http.StatusNotFound) } + // Embedded IdP (Dex) + case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"): + idpHandler.ServeHTTP(w, r) + // Management HTTP API (default) default: httpHandler.ServeHTTP(w, r) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 7c655f020..46e475143 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -10,8 +10,10 @@ import ( "slices" "time" + "github.com/gorilla/mux" grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/rs/cors" "github.com/rs/xid" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -19,7 +21,6 @@ import ( "google.golang.org/grpc/keepalive" cachestore "github.com/eko/gocache/lib/v4/store" - "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" @@ -27,16 +28,20 @@ import ( accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + activitystore "github.com/netbirdio/netbird/management/server/activity/store" nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util/crypt" ) +const apiPrefix = "/api" + var ( kaep = keepalive.EnforcementPolicy{ MinTime: 15 * time.Second, @@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store { func (s *BaseServer) EventStore() activity.Store { return Create(s, func() activity.Store { - integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics()) - if err != nil { - log.Fatalf("failed to initialize integration metrics: %v", err) + var err error + key := s.Config.DataStoreEncryptionKey + if key == "" { + log.Debugf("generate new activity store encryption key") + key, err = crypt.GenerateKey() + if err != nil { + log.Fatalf("failed to generate event store encryption key: %v", err) + } } - eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) + eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } @@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if +// the deployment isn't using the embedded variant. +func (s *BaseServer) IDPHandler() http.Handler { + embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager) + if !ok || embeddedIdP == nil { + return nil + } + return cors.AllowAll().Handler(embeddedIdP.Handler()) +} + +func (s *BaseServer) Router() *mux.Router { + return Create(s, func() *mux.Router { + return mux.NewRouter().PathPrefix(apiPrefix).Subrouter() + }) +} + func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { return Create(s, func() *middleware.APIRateLimiter { cfg, enabled := middleware.RateLimiterConfigFromEnv() diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 794c3ebe0..1b2556809 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" @@ -38,7 +39,7 @@ func (s *BaseServer) JobManager() *job.Manager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator( + integratedPeerValidator, err := validator.NewIntegratedValidator( context.Background(), s.PeersManager(), s.SettingsManager(), diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ea94245d5..a70da855a 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) - - s.AfterInit(func(s *BaseServer) { - manager.SetAccountManager(s.AccountManager()) - }) - - return manager + return permissions.NewManager(s.Store()) }) } @@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager { return idpManager } - return nil }) } @@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return &m }) } + +func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool { + return false +} diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 9b8716da1..63d13baab 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) + rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) { log.Tracef("custom handler set successfully") } -func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { +func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler { // Check if a custom handler was set (for multiplexing additional services) if customHandler, ok := s.GetContainer("customHandler"); ok { if handler, ok := customHandler.(http.Handler); ok { @@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht gRPCHandler.ServeHTTP(writer, request) case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: wsProxy.Handler().ServeHTTP(writer, request) + case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"): + idpHandler.ServeHTTP(writer, request) default: httpHandler.ServeHTTP(writer, request) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 1e2c710db..0abdb854d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -15,15 +15,13 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -32,12 +30,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/http/handlers/proxy" - nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -56,17 +52,14 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" nbinstance "github.com/netbirdio/netbird/management/server/instance" - "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" - // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetUserFromUserAuth, rateLimiter, appMetrics.GetMeter(), + isValidChildAccount, ) corsMiddleware := cors.AllowAll() - rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() - prefix := apiPrefix - router := rootRouter.PathPrefix(prefix).Subrouter() - router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler) - if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil { - return nil, fmt.Errorf("register integrations endpoints: %w", err) - } - - // Check if embedded IdP is enabled for instance manager - embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) - instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) + instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager) if err != nil { return nil, fmt.Errorf("failed to create instance manager: %w", err) } @@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks oauthHandler.RegisterEndpoints(router) } - // Mount embedded IdP handler at /oauth2 path if configured - if embeddedIdpEnabled { - rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler())) - } - - return rootRouter, nil + return router, nil } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6d075d9c2..34df0de23 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,8 +11,6 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" - "github.com/netbirdio/management-integrations/integrations" - serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) +type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { authManager serverauth.Manager @@ -35,6 +35,7 @@ type AuthMiddleware struct { syncUserJWTGroups SyncUserJWTGroupsFunc rateLimiter *APIRateLimiter patUsageTracker *PATUsageTracker + isValidChildAccount IsValidChildAccountFunc } // NewAuthMiddleware instance constructor @@ -45,6 +46,7 @@ func NewAuthMiddleware( getUserFromUserAuth GetUserFromUserAuthFunc, rateLimiter *APIRateLimiter, meter metric.Meter, + isValidChildAccount IsValidChildAccountFunc, ) *AuthMiddleware { var patUsageTracker *PATUsageTracker if meter != nil { @@ -62,6 +64,7 @@ func NewAuthMiddleware( getUserFromUserAuth: getUserFromUserAuth, rateLimiter: rateLimiter, patUsageTracker: patUsageTracker, + isValidChildAccount: isValidChildAccount, } } @@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { - if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) { + if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) { userAuth.AccountId = impersonate[0] userAuth.IsChild = true } @@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { - if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) { + if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) { userAuth.AccountId = impersonate[0] userAuth.IsChild = true } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 8f736fbfd..24cf8fce5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { }, disabledLimiter, nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { }, NewAPIRateLimiter(rateLimitConfig), nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { }, disabledLimiter, nil, + func(_ context.Context, _, _, _ string) bool { return false }, ) for _, tc := range tt { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 3c4ea98d0..8da9c7ad4 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,6 +7,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/metric/noop" @@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) + apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter() + apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/integrations/integrated_validator/validator/validator.go b/management/server/integrations/integrated_validator/validator/validator.go new file mode 100644 index 000000000..db1d34373 --- /dev/null +++ b/management/server/integrations/integrated_validator/validator/validator.go @@ -0,0 +1,62 @@ +package validator + +import ( + "context" + + cachestore "github.com/eko/gocache/lib/v4/store" + + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +type IntegratedValidatorImpl struct{} + +func NewIntegratedValidator(_ context.Context, _ peers.Manager, _ settings.Manager, _ activity.Store, _ cachestore.StoreInterface) (*IntegratedValidatorImpl, error) { + return &IntegratedValidatorImpl{}, nil +} + +func (v *IntegratedValidatorImpl) ValidateExtraSettings(context.Context, *types.ExtraSettings, *types.ExtraSettings, string, string) error { + return nil +} + +func (v *IntegratedValidatorImpl) ValidatePeer(_ context.Context, update *nbpeer.Peer, _ *nbpeer.Peer, _ string, _ string, _ string, _ []string, _ *types.ExtraSettings) (*nbpeer.Peer, bool, error) { + return update, false, nil +} + +func (v *IntegratedValidatorImpl) PreparePeer(_ context.Context, _ string, peer *nbpeer.Peer, _ []string, _ *types.ExtraSettings, _ bool) *nbpeer.Peer { + return peer.Copy() +} + +func (v *IntegratedValidatorImpl) IsNotValidPeer(_ context.Context, _ string, _ *nbpeer.Peer, _ []string, _ *types.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (v *IntegratedValidatorImpl) GetValidatedPeers(_ context.Context, _ string, _ []*types.Group, peers []*nbpeer.Peer, _ *types.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, p := range peers { + validatedPeers[p.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (v *IntegratedValidatorImpl) GetInvalidPeers(_ context.Context, _ string, _ *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + +func (v *IntegratedValidatorImpl) PeerDeleted(_ context.Context, _, _ string, _ *types.ExtraSettings) error { + return nil +} + +func (v *IntegratedValidatorImpl) SetPeerInvalidationListener(_ func(accountID string, peerIDs []string)) { +} + +func (v *IntegratedValidatorImpl) Stop(_ context.Context) { +} + +func (v *IntegratedValidatorImpl) ValidateFlowResponse(_ context.Context, _ string, flowResponse *proto.PKCEAuthorizationFlow) *proto.PKCEAuthorizationFlow { + return flowResponse +} diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index a8e8172dc..be2c009ad 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -17,7 +17,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -103,7 +103,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } - ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + ia, _ := validator.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err)